diff options
author | Jiayu Liu <jiayu.liu@airbnb.com> | 2022-03-11 04:55:13 +0100 |
---|---|---|
committer | Jens Geyer <jensg@apache.org> | 2022-04-20 23:37:58 +0200 |
commit | eac5103f8204021f7b5436001319c2b17ed5644f (patch) | |
tree | 22610465f5927c466b2e9baa87e55fe116347c6c /lib/java/src/main/java/org/apache/thrift/transport | |
parent | 8987820e84ac26392293ab40480cf8f2971fb314 (diff) | |
download | thrift-eac5103f8204021f7b5436001319c2b17ed5644f.tar.gz |
THRIFT-5545: use gradle convention in organizing java project
Client: java
Patch: Jiayu Liu
This closes #2546
Diffstat (limited to 'lib/java/src/main/java/org/apache/thrift/transport')
54 files changed, 7560 insertions, 0 deletions
diff --git a/lib/java/src/main/java/org/apache/thrift/transport/AutoExpandingBuffer.java b/lib/java/src/main/java/org/apache/thrift/transport/AutoExpandingBuffer.java new file mode 100644 index 000000000..b355d11ca --- /dev/null +++ b/lib/java/src/main/java/org/apache/thrift/transport/AutoExpandingBuffer.java @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.thrift.transport; + +import java.util.Arrays; + +/** + * Helper class that wraps a byte[] so that it can expand and be reused. Users + * should call resizeIfNecessary to make sure the buffer has suitable capacity, + * and then use the array as needed. Note that the internal array will grow at a + * rate slightly faster than the requested capacity with the (untested) + * objective of avoiding expensive buffer allocations and copies. + */ +public class AutoExpandingBuffer { + private byte[] array; + + public AutoExpandingBuffer(int initialCapacity) { + this.array = new byte[initialCapacity]; + } + + public void resizeIfNecessary(int size) { + final int currentCapacity = this.array.length; + if (currentCapacity < size) { + // Increase by a factor of 1.5x + int growCapacity = currentCapacity + (currentCapacity >> 1); + int newCapacity = Math.max(growCapacity, size); + this.array = Arrays.copyOf(array, newCapacity); + } + } + + public byte[] array() { + return this.array; + } +} diff --git a/lib/java/src/main/java/org/apache/thrift/transport/AutoExpandingBufferReadTransport.java b/lib/java/src/main/java/org/apache/thrift/transport/AutoExpandingBufferReadTransport.java new file mode 100644 index 000000000..6fd4075b9 --- /dev/null +++ b/lib/java/src/main/java/org/apache/thrift/transport/AutoExpandingBufferReadTransport.java @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.thrift.transport; + +import org.apache.thrift.TConfiguration; + +/** + * TTransport for reading from an AutoExpandingBuffer. + */ +public class AutoExpandingBufferReadTransport extends TEndpointTransport { + + private final AutoExpandingBuffer buf; + + private int pos = 0; + private int limit = 0; + + public AutoExpandingBufferReadTransport(TConfiguration config, int initialCapacity) throws TTransportException { + super(config); + this.buf = new AutoExpandingBuffer(initialCapacity); + } + + public void fill(TTransport inTrans, int length) throws TTransportException { + buf.resizeIfNecessary(length); + inTrans.readAll(buf.array(), 0, length); + pos = 0; + limit = length; + } + + @Override + public void close() {} + + @Override + public boolean isOpen() { return true; } + + @Override + public void open() throws TTransportException {} + + @Override + public final int read(byte[] target, int off, int len) throws TTransportException { + int amtToRead = Math.min(len, getBytesRemainingInBuffer()); + if(amtToRead > 0){ + System.arraycopy(buf.array(), pos, target, off, amtToRead); + consumeBuffer(amtToRead); + } + return amtToRead; + } + + @Override + public void write(byte[] buf, int off, int len) throws TTransportException { + throw new UnsupportedOperationException(); + } + + @Override + public final void consumeBuffer(int len) { + pos += len; + } + + @Override + public final byte[] getBuffer() { + return buf.array(); + } + + @Override + public final int getBufferPosition() { + return pos; + } + + @Override + public final int getBytesRemainingInBuffer() { + return limit - pos; + } +} diff --git a/lib/java/src/main/java/org/apache/thrift/transport/AutoExpandingBufferWriteTransport.java b/lib/java/src/main/java/org/apache/thrift/transport/AutoExpandingBufferWriteTransport.java new file mode 100644 index 000000000..84b28b4c4 --- /dev/null +++ b/lib/java/src/main/java/org/apache/thrift/transport/AutoExpandingBufferWriteTransport.java @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.thrift.transport; + +import org.apache.thrift.TConfiguration; + +/** + * TTransport for writing to an AutoExpandingBuffer. + */ +public final class AutoExpandingBufferWriteTransport extends TEndpointTransport { + + private final AutoExpandingBuffer buf; + private int pos; + private int res; + + /** + * Constructor. + * @param config the configuration to use. Currently used for defining the maximum message size. + * @param initialCapacity the initial capacity of the buffer + * @param frontReserve space, if any, to reserve at the beginning such + * that the first write is after this reserve. + * This allows framed transport to reserve space + * for the frame buffer length. + * @throws IllegalArgumentException if initialCapacity is less than one + * @throws IllegalArgumentException if frontReserve is less than zero + * @throws IllegalArgumentException if frontReserve is greater than initialCapacity + */ + public AutoExpandingBufferWriteTransport(TConfiguration config, int initialCapacity, int frontReserve) throws TTransportException { + super(config); + if (initialCapacity < 1) { + throw new IllegalArgumentException("initialCapacity"); + } + if (frontReserve < 0 || initialCapacity < frontReserve) { + throw new IllegalArgumentException("frontReserve"); + } + this.buf = new AutoExpandingBuffer(initialCapacity); + this.pos = frontReserve; + this.res = frontReserve; + } + + @Override + public void close() {} + + @Override + public boolean isOpen() {return true;} + + @Override + public void open() throws TTransportException {} + + @Override + public int read(byte[] buf, int off, int len) throws TTransportException { + throw new UnsupportedOperationException(); + } + + @Override + public void write(byte[] toWrite, int off, int len) throws TTransportException { + buf.resizeIfNecessary(pos + len); + System.arraycopy(toWrite, off, buf.array(), pos, len); + pos += len; + } + + public AutoExpandingBuffer getBuf() { + return buf; + } + + /** + * @return length of the buffer, including any front reserve + */ + public int getLength() { + return pos; + } + + public void reset() { + pos = res; + } +} diff --git a/lib/java/src/main/java/org/apache/thrift/transport/TByteBuffer.java b/lib/java/src/main/java/org/apache/thrift/transport/TByteBuffer.java new file mode 100644 index 000000000..72d10c53d --- /dev/null +++ b/lib/java/src/main/java/org/apache/thrift/transport/TByteBuffer.java @@ -0,0 +1,104 @@ +package org.apache.thrift.transport; + +import org.apache.thrift.TConfiguration; + +import java.nio.BufferOverflowException; +import java.nio.BufferUnderflowException; +import java.nio.ByteBuffer; + +/** + * ByteBuffer-backed implementation of TTransport. + */ +public final class TByteBuffer extends TEndpointTransport { + private final ByteBuffer byteBuffer; + + /** + * Creates a new TByteBuffer wrapping a given NIO ByteBuffer. + * + * @param byteBuffer the NIO ByteBuffer to wrap. + * @throws TTransportException on error. + */ + public TByteBuffer(ByteBuffer byteBuffer) throws TTransportException { + super(new TConfiguration()); + this.byteBuffer = byteBuffer; + updateKnownMessageSize(byteBuffer.capacity()); + } + + @Override + public boolean isOpen() { + return true; + } + + @Override + public void open() { + } + + @Override + public void close() { + } + + @Override + public int read(byte[] buf, int off, int len) throws TTransportException { + // + checkReadBytesAvailable(len); + + final int n = Math.min(byteBuffer.remaining(), len); + if (n > 0) { + try { + byteBuffer.get(buf, off, n); + } catch (BufferUnderflowException e) { + throw new TTransportException("Unexpected end of input buffer", e); + } + } + return n; + } + + @Override + public void write(byte[] buf, int off, int len) throws TTransportException { + try { + byteBuffer.put(buf, off, len); + } catch (BufferOverflowException e) { + throw new TTransportException("Not enough room in output buffer", e); + } + } + + /** + * Gets the underlying NIO ByteBuffer. + * + * @return the underlying NIO ByteBuffer. + */ + public ByteBuffer getByteBuffer() { + return byteBuffer; + } + + /** + * Convenience method to call clear() on the underlying NIO ByteBuffer. + * + * @return this instance. + */ + public TByteBuffer clear() { + byteBuffer.clear(); + return this; + } + + /** + * Convenience method to call flip() on the underlying NIO ByteBuffer. + * + * @return this instance. + */ + public TByteBuffer flip() { + byteBuffer.flip(); + return this; + } + + /** + * Convenience method to convert the underlying NIO ByteBuffer to a plain old byte array. + * + * @return the byte array backing the underlying NIO ByteBuffer. + */ + public byte[] toByteArray() { + final byte[] data = new byte[byteBuffer.remaining()]; + byteBuffer.slice().get(data); + return data; + } +} diff --git a/lib/java/src/main/java/org/apache/thrift/transport/TEOFException.java b/lib/java/src/main/java/org/apache/thrift/transport/TEOFException.java new file mode 100644 index 000000000..b5ae6eff4 --- /dev/null +++ b/lib/java/src/main/java/org/apache/thrift/transport/TEOFException.java @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.transport; + +/** + * End of file, especially, the underlying socket is closed. + */ +public class TEOFException extends TTransportException { + + public TEOFException(String message) { + super(TTransportException.END_OF_FILE, message); + } +} diff --git a/lib/java/src/main/java/org/apache/thrift/transport/TEndpointTransport.java b/lib/java/src/main/java/org/apache/thrift/transport/TEndpointTransport.java new file mode 100644 index 000000000..f33b8b72d --- /dev/null +++ b/lib/java/src/main/java/org/apache/thrift/transport/TEndpointTransport.java @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.thrift.transport; + +import org.apache.thrift.TConfiguration; + +import java.util.Objects; + +public abstract class TEndpointTransport extends TTransport{ + + protected long getMaxMessageSize() { return getConfiguration().getMaxMessageSize(); } + + public int getMaxFrameSize() { return getConfiguration().getMaxFrameSize(); } + + public void setMaxFrameSize(int maxFrameSize) { getConfiguration().setMaxFrameSize(maxFrameSize); } + + protected long knownMessageSize; + protected long remainingMessageSize; + + private TConfiguration _configuration; + + public TConfiguration getConfiguration() { + return _configuration; + } + + public TEndpointTransport(TConfiguration config) throws TTransportException { + _configuration = Objects.isNull(config) ? new TConfiguration() : config; + + resetConsumedMessageSize(-1); + } + + /** + * Resets RemainingMessageSize to the configured maximum + * @param newSize + */ + protected void resetConsumedMessageSize(long newSize) throws TTransportException { + // full reset + if (newSize < 0) + { + knownMessageSize = getMaxMessageSize(); + remainingMessageSize = getMaxMessageSize(); + return; + } + + // update only: message size can shrink, but not grow + if (newSize > knownMessageSize) + throw new TTransportException(TTransportException.END_OF_FILE, "MaxMessageSize reached"); + + knownMessageSize = newSize; + remainingMessageSize = newSize; + } + + /** + * Updates RemainingMessageSize to reflect then known real message size (e.g. framed transport). + * Will throw if we already consumed too many bytes or if the new size is larger than allowed. + * @param size + */ + public void updateKnownMessageSize(long size) throws TTransportException { + long consumed = knownMessageSize - remainingMessageSize; + resetConsumedMessageSize(size == 0 ? -1 : size); + countConsumedMessageBytes(consumed); + } + + /** + * Throws if there are not enough bytes in the input stream to satisfy a read of numBytes bytes of data + * @param numBytes + */ + public void checkReadBytesAvailable(long numBytes) throws TTransportException { + if (remainingMessageSize < numBytes) + throw new TTransportException(TTransportException.END_OF_FILE, "MaxMessageSize reached"); + } + + /** + * Consumes numBytes from the RemainingMessageSize. + * @param numBytes + */ + protected void countConsumedMessageBytes(long numBytes) throws TTransportException { + if (remainingMessageSize >= numBytes) + { + remainingMessageSize -= numBytes; + } + else + { + remainingMessageSize = 0; + throw new TTransportException(TTransportException.END_OF_FILE, "MaxMessageSize reached"); + } + } + +} diff --git a/lib/java/src/main/java/org/apache/thrift/transport/TFileProcessor.java b/lib/java/src/main/java/org/apache/thrift/transport/TFileProcessor.java new file mode 100644 index 000000000..e36a5f384 --- /dev/null +++ b/lib/java/src/main/java/org/apache/thrift/transport/TFileProcessor.java @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.transport; + +import org.apache.thrift.TProcessor; +import org.apache.thrift.TException; +import org.apache.thrift.protocol.TProtocol; +import org.apache.thrift.protocol.TProtocolFactory; + +/** + * FileProcessor: helps in processing files generated by TFileTransport. + * Port of original cpp implementation + */ +public class TFileProcessor { + + private TProcessor processor_; + private TProtocolFactory inputProtocolFactory_; + private TProtocolFactory outputProtocolFactory_; + private TFileTransport inputTransport_; + private TTransport outputTransport_; + + public TFileProcessor(TProcessor processor, TProtocolFactory protocolFactory, + TFileTransport inputTransport, + TTransport outputTransport) { + processor_ = processor; + inputProtocolFactory_ = outputProtocolFactory_ = protocolFactory; + inputTransport_ = inputTransport; + outputTransport_ = outputTransport; + } + + public TFileProcessor(TProcessor processor, + TProtocolFactory inputProtocolFactory, + TProtocolFactory outputProtocolFactory, + TFileTransport inputTransport, + TTransport outputTransport) { + processor_ = processor; + inputProtocolFactory_ = inputProtocolFactory; + outputProtocolFactory_ = outputProtocolFactory; + inputTransport_ = inputTransport; + outputTransport_ = outputTransport; + } + + private void processUntil(int lastChunk) throws TException { + TProtocol ip = inputProtocolFactory_.getProtocol(inputTransport_); + TProtocol op = outputProtocolFactory_.getProtocol(outputTransport_); + int curChunk = inputTransport_.getCurChunk(); + + try { + while (lastChunk >= curChunk) { + processor_.process(ip, op); + int newChunk = inputTransport_.getCurChunk(); + curChunk = newChunk; + } + } catch (TTransportException e) { + // if we are processing the last chunk - we could have just hit EOF + // on EOF - trap the error and stop processing. + if(e.getType() != TTransportException.END_OF_FILE) + throw e; + else { + return; + } + } + } + + /** + * Process from start to last chunk both inclusive where chunks begin from 0 + + * @param startChunkNum first chunk to be processed + * @param endChunkNum last chunk to be processed + * @throws TException if endChunkNum is less than startChunkNum. + */ + public void processChunk(int startChunkNum, int endChunkNum) throws TException { + int numChunks = inputTransport_.getNumChunks(); + if(endChunkNum < 0) + endChunkNum += numChunks; + + if(startChunkNum < 0) + startChunkNum += numChunks; + + if(endChunkNum < startChunkNum) + throw new TException("endChunkNum " + endChunkNum + " is less than " + startChunkNum); + + inputTransport_.seekToChunk(startChunkNum); + processUntil(endChunkNum); + } + + /** + * Process a single chunk + * + * @param chunkNum chunk to be processed + * @throws TException on error while processing the given chunk. + */ + public void processChunk(int chunkNum) throws TException { + processChunk(chunkNum, chunkNum); + } + + /** + * Process a current chunk + * + * @throws TException on error while processing the given chunk. + */ + public void processChunk() throws TException { + processChunk(inputTransport_.getCurChunk()); + } +} diff --git a/lib/java/src/main/java/org/apache/thrift/transport/TFileTransport.java b/lib/java/src/main/java/org/apache/thrift/transport/TFileTransport.java new file mode 100644 index 000000000..61b68d279 --- /dev/null +++ b/lib/java/src/main/java/org/apache/thrift/transport/TFileTransport.java @@ -0,0 +1,642 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.transport; + +import java.io.BufferedInputStream; +import java.io.BufferedOutputStream; +import java.io.InputStream; +import java.io.OutputStream; +import java.io.IOException; +import java.util.Random; + +import org.apache.thrift.TConfiguration; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * FileTransport implementation of the TTransport interface. + * Currently this is a straightforward port of the cpp implementation + * + * It may make better sense to provide a basic stream access on top of the framed file format + * The FileTransport can then be a user of this framed file format with some additional logic + * for chunking. + */ +public class TFileTransport extends TTransport { + + private static final Logger LOGGER = LoggerFactory.getLogger(TFileTransport.class.getName()); + + public static class TruncableBufferedInputStream extends BufferedInputStream { + public void trunc() { + pos = count = 0; + } + public TruncableBufferedInputStream(InputStream in) { + super(in); + } + public TruncableBufferedInputStream(InputStream in, int size) { + super(in, size); + } + } + + + public static class Event { + private byte[] buf_; + private int nread_; + private int navailable_; + + /** + * Initialize an event. Initially, it has no valid contents + * + * @param buf byte array buffer to store event + */ + public Event(byte[] buf) { + buf_ = buf; + nread_ = navailable_ = 0; + } + + public byte[] getBuf() { return buf_;} + public int getSize() { return buf_.length; } + + + public void setAvailable(int sz) { nread_ = 0; navailable_=sz;} + public int getRemaining() { return (navailable_ - nread_); } + + public int emit(byte[] buf, int offset, int ndesired) { + if((ndesired == 0) || (ndesired > getRemaining())) + ndesired = getRemaining(); + + if(ndesired <= 0) + return (ndesired); + + System.arraycopy(buf_, nread_, buf, offset, ndesired); + nread_ += ndesired; + + return(ndesired); + } + } + + public static class ChunkState { + /** + * Chunk Size. Must be same across all implementations + */ + public static final int DEFAULT_CHUNK_SIZE = 16 * 1024 * 1024; + + private int chunk_size_ = DEFAULT_CHUNK_SIZE; + private long offset_ = 0; + + public ChunkState() {} + public ChunkState(int chunk_size) { chunk_size_ = chunk_size; } + + public void skip(int size) {offset_ += size; } + public void seek(long offset) {offset_ = offset;} + + public int getChunkSize() { return chunk_size_;} + public int getChunkNum() { return ((int)(offset_/chunk_size_));} + public int getRemaining() { return (chunk_size_ - ((int)(offset_ % chunk_size_)));} + public long getOffset() { return (offset_);} + } + + public enum TailPolicy { + + NOWAIT(0, 0), + WAIT_FOREVER(500, -1); + + /** + * Time in milliseconds to sleep before next read + * If 0, no sleep + */ + public final int timeout_; + + /** + * Number of retries before giving up + * if 0, no retries + * if -1, retry forever + */ + public final int retries_; + + /** + * ctor for policy + * + * @param timeout sleep time for this particular policy + * @param retries number of retries + */ + + TailPolicy(int timeout, int retries) { + timeout_ = timeout; + retries_ = retries; + } + } + + /** + * Current tailing policy + */ + TailPolicy currentPolicy_ = TailPolicy.NOWAIT; + + + /** + * Underlying file being read + */ + protected TSeekableFile inputFile_ = null; + + /** + * Underlying outputStream + */ + protected OutputStream outputStream_ = null; + + + /** + * Event currently read in + */ + Event currentEvent_ = null; + + /** + * InputStream currently being used for reading + */ + InputStream inputStream_ = null; + + /** + * current Chunk state + */ + ChunkState cs = null; + + /** + * is read only? + */ + private boolean readOnly_ = false; + + /** + * Get File Tailing Policy + * + * @return current read policy + */ + public TailPolicy getTailPolicy() { + return (currentPolicy_); + } + + /** + * Set file Tailing Policy + * + * @param policy New policy to set + * @return Old policy + */ + public TailPolicy setTailPolicy(TailPolicy policy) { + TailPolicy old = currentPolicy_; + currentPolicy_ = policy; + return (old); + } + + + /** + * Initialize read input stream + * + * @return input stream to read from file + */ + private InputStream createInputStream() throws TTransportException { + InputStream is; + try { + if(inputStream_ != null) { + ((TruncableBufferedInputStream)inputStream_).trunc(); + is = inputStream_; + } else { + is = new TruncableBufferedInputStream(inputFile_.getInputStream()); + } + } catch (IOException iox) { + throw new TTransportException(iox.getMessage(), iox); + } + return(is); + } + + /** + * Read (potentially tailing) an input stream + * + * @param is InputStream to read from + * @param buf Buffer to read into + * @param off Offset in buffer to read into + * @param len Number of bytes to read + * @param tp policy to use if we hit EOF + * + * @return number of bytes read + */ + private int tailRead(InputStream is, byte[] buf, + int off, int len, TailPolicy tp) throws TTransportException { + int orig_len = len; + try { + int retries = 0; + while(len > 0) { + int cnt = is.read(buf, off, len); + if(cnt > 0) { + off += cnt; + len -= cnt; + retries = 0; + cs.skip(cnt); // remember that we read so many bytes + } else if (cnt == -1) { + // EOF + retries++; + + if((tp.retries_ != -1) && tp.retries_ < retries) + return (orig_len - len); + + if(tp.timeout_ > 0) { + try {Thread.sleep(tp.timeout_);} catch(InterruptedException e) {} + } + } else { + // either non-zero or -1 is what the contract says! + throw new + TTransportException("Unexpected return from InputStream.read = " + + cnt); + } + } + } catch (IOException iox) { + throw new TTransportException(iox.getMessage(), iox); + } + + return(orig_len - len); + } + + /** + * Event is corrupted. Do recovery + * + * @return true if recovery could be performed and we can read more data + * false is returned only when nothing more can be read + */ + private boolean performRecovery() throws TTransportException { + int numChunks = getNumChunks(); + int curChunk = cs.getChunkNum(); + + if(curChunk >= (numChunks-1)) { + return false; + } + seekToChunk(curChunk+1); + return true; + } + + /** + * Read event from underlying file + * + * @return true if event could be read, false otherwise (on EOF) + */ + private boolean readEvent() throws TTransportException { + byte[] ebytes = new byte[4]; + int esize; + int nread; + int nrequested; + + retry: + do { + // corner case. read to end of chunk + nrequested = cs.getRemaining(); + if(nrequested < 4) { + nread = tailRead(inputStream_, ebytes, 0, nrequested, currentPolicy_); + if(nread != nrequested) { + return(false); + } + } + + // assuming serialized on little endian machine + nread = tailRead(inputStream_, ebytes, 0, 4, currentPolicy_); + if(nread != 4) { + return(false); + } + + esize=0; + for(int i=3; i>=0; i--) { + int val = (0x000000ff & (int)ebytes[i]); + esize |= (val << (i*8)); + } + + // check if event is corrupted and do recovery as required + if(esize > cs.getRemaining()) { + throw new TTransportException("FileTransport error: bad event size"); + /* + if(performRecovery()) { + esize=0; + } else { + return false; + } + */ + } + } while (esize == 0); + + // reset existing event or get a larger one + if(currentEvent_.getSize() < esize) + currentEvent_ = new Event(new byte [esize]); + + // populate the event + byte[] buf = currentEvent_.getBuf(); + nread = tailRead(inputStream_, buf, 0, esize, currentPolicy_); + if(nread != esize) { + return(false); + } + currentEvent_.setAvailable(esize); + return(true); + } + + /** + * open if both input/output open unless readonly + * + * @return true + */ + public boolean isOpen() { + return ((inputStream_ != null) && (readOnly_ || (outputStream_ != null))); + } + + + /** + * Diverging from the cpp model and sticking to the TSocket model + * Files are not opened in ctor - but in explicit open call + */ + public void open() throws TTransportException { + if (isOpen()) + throw new TTransportException(TTransportException.ALREADY_OPEN); + + try { + inputStream_ = createInputStream(); + cs = new ChunkState(); + currentEvent_ = new Event(new byte [256]); + + if(!readOnly_) + outputStream_ = new BufferedOutputStream(inputFile_.getOutputStream()); + } catch (IOException iox) { + throw new TTransportException(TTransportException.NOT_OPEN, iox); + } + } + + /** + * Closes the transport. + */ + public void close() { + if (inputFile_ != null) { + try { + inputFile_.close(); + } catch (IOException iox) { + LOGGER.warn("WARNING: Error closing input file: " + + iox.getMessage()); + } + inputFile_ = null; + } + if (outputStream_ != null) { + try { + outputStream_.close(); + } catch (IOException iox) { + LOGGER.warn("WARNING: Error closing output stream: " + + iox.getMessage()); + } + outputStream_ = null; + } + } + + + /** + * File Transport ctor + * + * @param path File path to read and write from + * @param readOnly Whether this is a read-only transport + * @throws IOException if there is an error accessing the file. + */ + public TFileTransport(final String path, boolean readOnly) throws IOException { + inputFile_ = new TStandardFile(path); + readOnly_ = readOnly; + } + + /** + * File Transport ctor + * + * @param inputFile open TSeekableFile to read/write from + * @param readOnly Whether this is a read-only transport + */ + public TFileTransport(TSeekableFile inputFile, boolean readOnly) { + inputFile_ = inputFile; + readOnly_ = readOnly; + } + + + /** + * Cloned from TTransport.java:readAll(). Only difference is throwing an EOF exception + * where one is detected. + */ + public int readAll(byte[] buf, int off, int len) + throws TTransportException { + int got = 0; + int ret = 0; + while (got < len) { + ret = read(buf, off+got, len-got); + if (ret < 0) { + throw new TTransportException("Error in reading from file"); + } + if(ret == 0) { + throw new TTransportException(TTransportException.END_OF_FILE, + "End of File reached"); + } + got += ret; + } + return got; + } + + + /** + * Reads up to len bytes into buffer buf, starting at offset off. + * + * @param buf Array to read into + * @param off Index to start reading at + * @param len Maximum number of bytes to read + * @return The number of bytes actually read + * @throws TTransportException if there was an error reading data + */ + public int read(byte[] buf, int off, int len) throws TTransportException { + if(!isOpen()) + throw new TTransportException(TTransportException.NOT_OPEN, + "Must open before reading"); + + if(currentEvent_.getRemaining() == 0) { + if(!readEvent()) + return(0); + } + + int nread = currentEvent_.emit(buf, off, len); + return nread; + } + + public int getNumChunks() throws TTransportException { + if(!isOpen()) + throw new TTransportException(TTransportException.NOT_OPEN, + "Must open before getNumChunks"); + try { + long len = inputFile_.length(); + if(len == 0) + return 0; + else + return (((int)(len/cs.getChunkSize())) + 1); + + } catch (IOException iox) { + throw new TTransportException(iox.getMessage(), iox); + } + } + + public int getCurChunk() throws TTransportException { + if(!isOpen()) + throw new TTransportException(TTransportException.NOT_OPEN, + "Must open before getCurChunk"); + return (cs.getChunkNum()); + + } + + + public void seekToChunk(int chunk) throws TTransportException { + if(!isOpen()) + throw new TTransportException(TTransportException.NOT_OPEN, + "Must open before seeking"); + + int numChunks = getNumChunks(); + + // file is empty, seeking to chunk is pointless + if (numChunks == 0) { + return; + } + + // negative indicates reverse seek (from the end) + if (chunk < 0) { + chunk += numChunks; + } + + // too large a value for reverse seek, just seek to beginning + if (chunk < 0) { + chunk = 0; + } + + long eofOffset=0; + boolean seekToEnd = (chunk >= numChunks); + if(seekToEnd) { + chunk = chunk - 1; + try { eofOffset = inputFile_.length(); } + catch (IOException iox) {throw new TTransportException(iox.getMessage(), + iox);} + } + + if(chunk*cs.getChunkSize() != cs.getOffset()) { + try { inputFile_.seek((long)chunk*cs.getChunkSize()); } + catch (IOException iox) { + throw new TTransportException("Seek to chunk " + + chunk + " " +iox.getMessage(), iox); + } + + cs.seek((long)chunk*cs.getChunkSize()); + currentEvent_.setAvailable(0); + inputStream_ = createInputStream(); + } + + if(seekToEnd) { + // waiting forever here - otherwise we can hit EOF and end up + // having consumed partial data from the data stream. + TailPolicy old = setTailPolicy(TailPolicy.WAIT_FOREVER); + while(cs.getOffset() < eofOffset) { readEvent(); } + currentEvent_.setAvailable(0); + setTailPolicy(old); + } + } + + public void seekToEnd() throws TTransportException { + if(!isOpen()) + throw new TTransportException(TTransportException.NOT_OPEN, + "Must open before seeking"); + seekToChunk(getNumChunks()); + } + + + /** + * Writes up to len bytes from the buffer. + * + * @param buf The output data buffer + * @param off The offset to start writing from + * @param len The number of bytes to write + * @throws TTransportException if there was an error writing data + */ + public void write(byte[] buf, int off, int len) throws TTransportException { + throw new TTransportException("Not Supported"); + } + + /** + * Flush any pending data out of a transport buffer. + * + * @throws TTransportException if there was an error writing out data. + */ + public void flush() throws TTransportException { + throw new TTransportException("Not Supported"); + } + + + @Override + public TConfiguration getConfiguration() { + return null; + } + + @Override + public void updateKnownMessageSize(long size) throws TTransportException { + + } + + @Override + public void checkReadBytesAvailable(long numBytes) throws TTransportException { + + } + + /** + * test program + * + */ + public static void main(String[] args) throws Exception { + + int num_chunks = 10; + + if((args.length < 1) || args[0].equals("--help") + || args[0].equals("-h") || args[0].equals("-?")) { + printUsage(); + } + + if(args.length > 1) { + try { + num_chunks = Integer.parseInt(args[1]); + } catch (Exception e) { + LOGGER.error("Cannot parse " + args[1]); + printUsage(); + } + } + + TFileTransport t = new TFileTransport(args[0], true); + t.open(); + LOGGER.info("NumChunks="+t.getNumChunks()); + + Random r = new Random(); + for(int j=0; j<num_chunks; j++) { + byte[] buf = new byte[4096]; + int cnum = r.nextInt(t.getNumChunks()-1); + LOGGER.info("Reading chunk "+cnum); + t.seekToChunk(cnum); + for(int i=0; i<4096; i++) { + t.read(buf, 0, 4096); + } + } + } + + private static void printUsage() { + LOGGER.error("Usage: TFileTransport <filename> [num_chunks]"); + LOGGER.error(" (Opens and reads num_chunks chunks from file randomly)"); + System.exit(1); + } + +} diff --git a/lib/java/src/main/java/org/apache/thrift/transport/THttpClient.java b/lib/java/src/main/java/org/apache/thrift/transport/THttpClient.java new file mode 100644 index 000000000..574682248 --- /dev/null +++ b/lib/java/src/main/java/org/apache/thrift/transport/THttpClient.java @@ -0,0 +1,389 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.transport; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.net.HttpURLConnection; +import java.net.URL; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +import org.apache.http.HttpEntity; +import org.apache.http.HttpHost; +import org.apache.http.HttpResponse; +import org.apache.http.HttpStatus; +import org.apache.http.client.HttpClient; +import org.apache.http.client.config.RequestConfig; +import org.apache.http.client.methods.HttpPost; +import org.apache.http.entity.ByteArrayEntity; +import org.apache.thrift.TConfiguration; + +/** + * HTTP implementation of the TTransport interface. Used for working with a + * Thrift web services implementation (using for example TServlet). + * + * This class offers two implementations of the HTTP transport. + * One uses HttpURLConnection instances, the other HttpClient from Apache + * Http Components. + * The chosen implementation depends on the constructor used to + * create the THttpClient instance. + * Using the THttpClient(String url) constructor or passing null as the + * HttpClient to THttpClient(String url, HttpClient client) will create an + * instance which will use HttpURLConnection. + * + * When using HttpClient, the following configuration leads to 5-15% + * better performance than the HttpURLConnection implementation: + * + * http.protocol.version=HttpVersion.HTTP_1_1 + * http.protocol.content-charset=UTF-8 + * http.protocol.expect-continue=false + * http.connection.stalecheck=false + * + * Also note that under high load, the HttpURLConnection implementation + * may exhaust the open file descriptor limit. + * + * @see <a href="https://issues.apache.org/jira/browse/THRIFT-970">THRIFT-970</a> + */ + +public class THttpClient extends TEndpointTransport { + + private final URL url_; + + private final ByteArrayOutputStream requestBuffer_ = new ByteArrayOutputStream(); + + private InputStream inputStream_ = null; + + private int connectTimeout_ = 0; + + private int readTimeout_ = 0; + + private Map<String,String> customHeaders_ = null; + + private final HttpHost host; + + private final HttpClient client; + + private static final Map<String, String> DEFAULT_HEADERS = Collections.unmodifiableMap(getDefaultHeaders()); + + public static class Factory extends TTransportFactory { + + private final String url; + private final HttpClient client; + + public Factory(String url) { + this.url = url; + this.client = null; + } + + public Factory(String url, HttpClient client) { + this.url = url; + this.client = client; + } + + @Override + public TTransport getTransport(TTransport trans) { + try { + if (null != client) { + return new THttpClient(trans.getConfiguration(), url, client); + } else { + return new THttpClient(trans.getConfiguration(), url); + } + } catch (TTransportException tte) { + return null; + } + } + } + + public THttpClient(TConfiguration config, String url) throws TTransportException { + super(config); + try { + url_ = new URL(url); + this.client = null; + this.host = null; + } catch (IOException iox) { + throw new TTransportException(iox); + } + } + + public THttpClient(String url) throws TTransportException { + super(new TConfiguration()); + try { + url_ = new URL(url); + this.client = null; + this.host = null; + } catch (IOException iox) { + throw new TTransportException(iox); + } + } + + public THttpClient(TConfiguration config, String url, HttpClient client) throws TTransportException { + super(config); + try { + url_ = new URL(url); + this.client = client; + this.host = new HttpHost(url_.getHost(), -1 == url_.getPort() ? url_.getDefaultPort() : url_.getPort(), url_.getProtocol()); + } catch (IOException iox) { + throw new TTransportException(iox); + } + } + + public THttpClient(String url, HttpClient client) throws TTransportException { + super(new TConfiguration()); + try { + url_ = new URL(url); + this.client = client; + this.host = new HttpHost(url_.getHost(), -1 == url_.getPort() ? url_.getDefaultPort() : url_.getPort(), url_.getProtocol()); + } catch (IOException iox) { + throw new TTransportException(iox); + } + } + + public void setConnectTimeout(int timeout) { + connectTimeout_ = timeout; + } + + public void setReadTimeout(int timeout) { + readTimeout_ = timeout; + } + + public void setCustomHeaders(Map<String,String> headers) { + customHeaders_ = new HashMap<>(headers); + } + + public void setCustomHeader(String key, String value) { + if (customHeaders_ == null) { + customHeaders_ = new HashMap<>(); + } + customHeaders_.put(key, value); + } + + @Override + public void open() {} + + @Override + public void close() { + if (null != inputStream_) { + try { + inputStream_.close(); + } catch (IOException ioe) { + } + inputStream_ = null; + } + } + + @Override + public boolean isOpen() { + return true; + } + + @Override + public int read(byte[] buf, int off, int len) throws TTransportException { + if (inputStream_ == null) { + throw new TTransportException("Response buffer is empty, no request."); + } + + checkReadBytesAvailable(len); + + try { + int ret = inputStream_.read(buf, off, len); + if (ret == -1) { + throw new TTransportException("No more data available."); + } + countConsumedMessageBytes(ret); + + return ret; + } catch (IOException iox) { + throw new TTransportException(iox); + } + } + + @Override + public void write(byte[] buf, int off, int len) { + requestBuffer_.write(buf, off, len); + } + + private RequestConfig getRequestConfig() { + RequestConfig requestConfig = RequestConfig.DEFAULT; + if (connectTimeout_ > 0) { + requestConfig = RequestConfig.copy(requestConfig).setConnectionRequestTimeout(connectTimeout_).build(); + } + if (readTimeout_ > 0) { + requestConfig = RequestConfig.copy(requestConfig).setSocketTimeout(readTimeout_).build(); + } + return requestConfig; + } + + private static Map<String, String> getDefaultHeaders() { + Map<String, String> headers = new HashMap<>(); + headers.put("Content-Type", "application/x-thrift"); + headers.put("Accept", "application/x-thrift"); + headers.put("User-Agent", "Java/THttpClient/HC"); + return headers; + } + + /** + * copy from org.apache.http.util.EntityUtils#consume. Android has it's own httpcore + * that doesn't have a consume. + */ + private static void consume(final HttpEntity entity) throws IOException { + if (entity == null) { + return; + } + if (entity.isStreaming()) { + InputStream instream = entity.getContent(); + if (instream != null) { + instream.close(); + } + } + } + + private void flushUsingHttpClient() throws TTransportException { + if (null == this.client) { + throw new TTransportException("Null HttpClient, aborting."); + } + + // Extract request and reset buffer + byte[] data = requestBuffer_.toByteArray(); + requestBuffer_.reset(); + + HttpPost post = new HttpPost(this.url_.getFile()); + try { + // Set request to path + query string + post.setConfig(getRequestConfig()); + DEFAULT_HEADERS.forEach(post::addHeader); + if (null != customHeaders_) { + customHeaders_.forEach(post::addHeader); + } + post.setEntity(new ByteArrayEntity(data)); + HttpResponse response = this.client.execute(this.host, post); + handleResponse(response); + } catch (IOException ioe) { + // Abort method so the connection gets released back to the connection manager + post.abort(); + throw new TTransportException(ioe); + } finally { + resetConsumedMessageSize(-1); + post.releaseConnection(); + } + } + + private void handleResponse(HttpResponse response) throws TTransportException { + // Retrieve the InputStream BEFORE checking the status code so + // resources get freed in the with clause. + try (InputStream is = response.getEntity().getContent()) { + int responseCode = response.getStatusLine().getStatusCode(); + if (responseCode != HttpStatus.SC_OK) { + throw new TTransportException("HTTP Response code: " + responseCode); + } + byte[] readByteArray = readIntoByteArray(is); + try { + // Indicate we're done with the content. + consume(response.getEntity()); + } catch (IOException ioe) { + // We ignore this exception, it might only mean the server has no + // keep-alive capability. + } + inputStream_ = new ByteArrayInputStream(readByteArray); + } catch (IOException ioe) { + throw new TTransportException(ioe); + } + } + + /** + * Read the responses into a byte array so we can release the connection + * early. This implies that the whole content will have to be read in + * memory, and that momentarily we might use up twice the memory (while the + * thrift struct is being read up the chain). + * Proceeding differently might lead to exhaustion of connections and thus + * to app failure. + * + * @param is input stream + * @return read bytes + * @throws IOException when exception during read + */ + private static byte[] readIntoByteArray(InputStream is) throws IOException { + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + byte[] buf = new byte[1024]; + int len; + do { + len = is.read(buf); + if (len > 0) { + baos.write(buf, 0, len); + } + } while (-1 != len); + return baos.toByteArray(); + } + + public void flush() throws TTransportException { + + if (null != this.client) { + flushUsingHttpClient(); + return; + } + + // Extract request and reset buffer + byte[] data = requestBuffer_.toByteArray(); + requestBuffer_.reset(); + + try { + // Create connection object + HttpURLConnection connection = (HttpURLConnection)url_.openConnection(); + + // Timeouts, only if explicitly set + if (connectTimeout_ > 0) { + connection.setConnectTimeout(connectTimeout_); + } + if (readTimeout_ > 0) { + connection.setReadTimeout(readTimeout_); + } + + // Make the request + connection.setRequestMethod("POST"); + connection.setRequestProperty("Content-Type", "application/x-thrift"); + connection.setRequestProperty("Accept", "application/x-thrift"); + connection.setRequestProperty("User-Agent", "Java/THttpClient"); + if (customHeaders_ != null) { + for (Map.Entry<String, String> header : customHeaders_.entrySet()) { + connection.setRequestProperty(header.getKey(), header.getValue()); + } + } + connection.setDoOutput(true); + connection.connect(); + connection.getOutputStream().write(data); + + int responseCode = connection.getResponseCode(); + if (responseCode != HttpURLConnection.HTTP_OK) { + throw new TTransportException("HTTP Response code: " + responseCode); + } + + // Read the responses + inputStream_ = connection.getInputStream(); + + } catch (IOException iox) { + throw new TTransportException(iox); + } finally { + resetConsumedMessageSize(-1); + } + } +} diff --git a/lib/java/src/main/java/org/apache/thrift/transport/TIOStreamTransport.java b/lib/java/src/main/java/org/apache/thrift/transport/TIOStreamTransport.java new file mode 100644 index 000000000..d5b459c13 --- /dev/null +++ b/lib/java/src/main/java/org/apache/thrift/transport/TIOStreamTransport.java @@ -0,0 +1,219 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.transport; + +import org.apache.thrift.TConfiguration; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.net.SocketTimeoutException; + +/** + * This is the most commonly used base transport. It takes an InputStream or + * an OutputStream or both and uses it/them to perform transport operations. + * This allows for compatibility with all the nice constructs Java already + * has to provide a variety of types of streams. + * + */ +public class TIOStreamTransport extends TEndpointTransport { + + private static final Logger LOGGER = LoggerFactory.getLogger(TIOStreamTransport.class.getName()); + + /** Underlying inputStream */ + protected InputStream inputStream_ = null; + + /** Underlying outputStream */ + protected OutputStream outputStream_ = null; + + /** + * Subclasses can invoke the default constructor and then assign the input + * streams in the open method. + */ + protected TIOStreamTransport(TConfiguration config) throws TTransportException { + super(config); + } + + /** + * Subclasses can invoke the default constructor and then assign the input + * streams in the open method. + */ + protected TIOStreamTransport() throws TTransportException { + super(new TConfiguration()); + } + + /** + * Input stream constructor, constructs an input only transport. + * + * @param config + * @param is Input stream to read from + */ + public TIOStreamTransport(TConfiguration config, InputStream is) throws TTransportException { + super(config); + inputStream_ = is; + } + /** + * Input stream constructor, constructs an input only transport. + * + * @param is Input stream to read from + */ + public TIOStreamTransport(InputStream is) throws TTransportException { + super(new TConfiguration()); + inputStream_ = is; + } + + /** + * Output stream constructor, constructs an output only transport. + * + * @param config + * @param os Output stream to write to + */ + public TIOStreamTransport(TConfiguration config, OutputStream os) throws TTransportException { + super(config); + outputStream_ = os; + } + + /** + * Output stream constructor, constructs an output only transport. + * + * @param os Output stream to write to + */ + public TIOStreamTransport(OutputStream os) throws TTransportException { + super(new TConfiguration()); + outputStream_ = os; + } + + /** + * Two-way stream constructor. + * + * @param config + * @param is Input stream to read from + * @param os Output stream to read from + */ + public TIOStreamTransport(TConfiguration config, InputStream is, OutputStream os) throws TTransportException { + super(config); + inputStream_ = is; + outputStream_ = os; + } + + /** + * Two-way stream constructor. + * + * @param is Input stream to read from + * @param os Output stream to read from + */ + public TIOStreamTransport(InputStream is, OutputStream os) throws TTransportException { + super(new TConfiguration()); + inputStream_ = is; + outputStream_ = os; + } + + /** + * + * @return false after close is called. + */ + public boolean isOpen() { + return inputStream_ != null || outputStream_ != null; + } + + /** + * The streams must already be open. This method does nothing. + */ + public void open() throws TTransportException {} + + /** + * Closes both the input and output streams. + */ + public void close() { + try { + if (inputStream_ != null) { + try { + inputStream_.close(); + } catch (IOException iox) { + LOGGER.warn("Error closing input stream.", iox); + } + } + if (outputStream_ != null) { + try { + outputStream_.close(); + } catch (IOException iox) { + LOGGER.warn("Error closing output stream.", iox); + } + } + } finally { + inputStream_ = null; + outputStream_ = null; + } + } + + /** + * Reads from the underlying input stream if not null. + */ + public int read(byte[] buf, int off, int len) throws TTransportException { + if (inputStream_ == null) { + throw new TTransportException(TTransportException.NOT_OPEN, "Cannot read from null inputStream"); + } + int bytesRead; + try { + bytesRead = inputStream_.read(buf, off, len); + } catch (SocketTimeoutException ste) { + throw new TTransportException(TTransportException.TIMED_OUT, ste); + } catch (IOException iox) { + throw new TTransportException(TTransportException.UNKNOWN, iox); + } + if (bytesRead < 0) { + throw new TTransportException(TTransportException.END_OF_FILE, "Socket is closed by peer."); + } + return bytesRead; + } + + /** + * Writes to the underlying output stream if not null. + */ + public void write(byte[] buf, int off, int len) throws TTransportException { + if (outputStream_ == null) { + throw new TTransportException(TTransportException.NOT_OPEN, "Cannot write to null outputStream"); + } + try { + outputStream_.write(buf, off, len); + } catch (IOException iox) { + throw new TTransportException(TTransportException.UNKNOWN, iox); + } + } + + /** + * Flushes the underlying output stream if not null. + */ + public void flush() throws TTransportException { + if (outputStream_ == null) { + throw new TTransportException(TTransportException.NOT_OPEN, "Cannot flush null outputStream"); + } + try { + outputStream_.flush(); + + resetConsumedMessageSize(-1); + + } catch (IOException iox) { + throw new TTransportException(TTransportException.UNKNOWN, iox); + } + } +} diff --git a/lib/java/src/main/java/org/apache/thrift/transport/TMemoryBuffer.java b/lib/java/src/main/java/org/apache/thrift/transport/TMemoryBuffer.java new file mode 100644 index 000000000..d9a3cc928 --- /dev/null +++ b/lib/java/src/main/java/org/apache/thrift/transport/TMemoryBuffer.java @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.transport; + +import org.apache.thrift.TByteArrayOutputStream; +import org.apache.thrift.TConfiguration; + +import java.nio.charset.Charset; + +/** + * Memory buffer-based implementation of the TTransport interface. + */ +public class TMemoryBuffer extends TEndpointTransport { + /** + * Create a TMemoryBuffer with an initial buffer size of <i>size</i>. The + * internal buffer will grow as necessary to accommodate the size of the data + * being written to it. + * + * @param size the initial size of the buffer + * @throws TTransportException on error initializing the underlying transport. + */ + public TMemoryBuffer(int size) throws TTransportException { + super(new TConfiguration()); + arr_ = new TByteArrayOutputStream(size); + updateKnownMessageSize(size); + } + + /** + * Create a TMemoryBuffer with an initial buffer size of <i>size</i>. The + * internal buffer will grow as necessary to accommodate the size of the data + * being written to it. + * + * @param config the configuration to use. + * @param size the initial size of the buffer + * @throws TTransportException on error initializing the underlying transport. + */ + public TMemoryBuffer(TConfiguration config, int size) throws TTransportException { + super(config); + arr_ = new TByteArrayOutputStream(size); + updateKnownMessageSize(size); + } + + @Override + public boolean isOpen() { + return true; + } + + @Override + public void open() { + /* Do nothing */ + } + + @Override + public void close() { + /* Do nothing */ + } + + @Override + public int read(byte[] buf, int off, int len) throws TTransportException { + checkReadBytesAvailable(len); + byte[] src = arr_.get(); + int amtToRead = (len > arr_.len() - pos_ ? arr_.len() - pos_ : len); + + if (amtToRead > 0) { + System.arraycopy(src, pos_, buf, off, amtToRead); + pos_ += amtToRead; + } + return amtToRead; + } + + @Override + public void write(byte[] buf, int off, int len) { + arr_.write(buf, off, len); + } + + /** + * Output the contents of the memory buffer as a String, using the supplied + * encoding + * @param charset the encoding to use + * @return the contents of the memory buffer as a String + */ + public String toString(Charset charset) { + return arr_.toString(charset); + } + + public String inspect() { + StringBuilder buf = new StringBuilder(); + byte[] bytes = arr_.toByteArray(); + for (int i = 0; i < bytes.length; i++) { + buf.append(pos_ == i ? "==>" : "" ).append(Integer.toHexString(bytes[i] & 0xff)).append(" "); + } + return buf.toString(); + } + + // The contents of the buffer + private TByteArrayOutputStream arr_; + + // Position to read next byte from + private int pos_; + + public int length() { + return arr_.size(); + } + + public byte[] getArray() { + return arr_.get(); + } +} diff --git a/lib/java/src/main/java/org/apache/thrift/transport/TMemoryInputTransport.java b/lib/java/src/main/java/org/apache/thrift/transport/TMemoryInputTransport.java new file mode 100644 index 000000000..6cb06fc37 --- /dev/null +++ b/lib/java/src/main/java/org/apache/thrift/transport/TMemoryInputTransport.java @@ -0,0 +1,124 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.thrift.transport; + +import org.apache.thrift.TConfiguration; + +public final class TMemoryInputTransport extends TEndpointTransport { + + private byte[] buf_; + private int pos_; + private int endPos_; + + public TMemoryInputTransport() throws TTransportException { + this(new TConfiguration()); + } + + public TMemoryInputTransport(TConfiguration _configuration) throws TTransportException { + this(_configuration, new byte[0]); + } + + public TMemoryInputTransport(byte[] buf) throws TTransportException { + this(new TConfiguration(), buf); + } + + public TMemoryInputTransport(TConfiguration _configuration, byte[] buf) throws TTransportException { + this(_configuration, buf, 0, buf.length); + } + + public TMemoryInputTransport(byte[] buf, int offset, int length) throws TTransportException { + this(new TConfiguration(), buf, offset, length); + } + + public TMemoryInputTransport(TConfiguration _configuration, byte[] buf, int offset, int length) throws TTransportException { + super(_configuration); + reset(buf, offset, length); + updateKnownMessageSize(length); + } + + public void reset(byte[] buf) { + reset(buf, 0, buf.length); + } + + public void reset(byte[] buf, int offset, int length) { + buf_ = buf; + pos_ = offset; + endPos_ = offset + length; + try { + resetConsumedMessageSize(-1); + } catch (TTransportException e) { + // ignore + } + } + + public void clear() { + buf_ = null; + try { + resetConsumedMessageSize(-1); + } catch (TTransportException e) { + // ignore + } + } + + @Override + public void close() {} + + @Override + public boolean isOpen() { + return true; + } + + @Override + public void open() throws TTransportException {} + + @Override + public int read(byte[] buf, int off, int len) throws TTransportException { + int bytesRemaining = getBytesRemainingInBuffer(); + int amtToRead = (len > bytesRemaining ? bytesRemaining : len); + if (amtToRead > 0) { + System.arraycopy(buf_, pos_, buf, off, amtToRead); + consumeBuffer(amtToRead); + countConsumedMessageBytes(amtToRead); + } + return amtToRead; + } + + @Override + public void write(byte[] buf, int off, int len) throws TTransportException { + throw new UnsupportedOperationException("No writing allowed!"); + } + + @Override + public byte[] getBuffer() { + return buf_; + } + + public int getBufferPosition() { + return pos_; + } + + public int getBytesRemainingInBuffer() { + return endPos_ - pos_; + } + + public void consumeBuffer(int len) { + pos_ += len; + } + +} diff --git a/lib/java/src/main/java/org/apache/thrift/transport/TMemoryTransport.java b/lib/java/src/main/java/org/apache/thrift/transport/TMemoryTransport.java new file mode 100644 index 000000000..0172ca816 --- /dev/null +++ b/lib/java/src/main/java/org/apache/thrift/transport/TMemoryTransport.java @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.transport; + +import java.nio.ByteBuffer; + +import org.apache.thrift.TByteArrayOutputStream; +import org.apache.thrift.TConfiguration; + +/** + * In memory transport with separate buffers for input and output. + */ +public class TMemoryTransport extends TEndpointTransport { + + private final ByteBuffer inputBuffer; + private final TByteArrayOutputStream outputBuffer; + + public TMemoryTransport(byte[] input) throws TTransportException { + super(new TConfiguration()); + inputBuffer = ByteBuffer.wrap(input); + outputBuffer = new TByteArrayOutputStream(1024); + updateKnownMessageSize(input.length); + } + + public TMemoryTransport(TConfiguration config, byte[] input) throws TTransportException { + super(config); + inputBuffer = ByteBuffer.wrap(input); + outputBuffer = new TByteArrayOutputStream(1024); + updateKnownMessageSize(input.length); + } + + @Override + public boolean isOpen() { + return true; + } + + /** + * Opening on an in memory transport should have no effect. + */ + @Override + public void open() { + // Do nothing. + } + + @Override + public void close() { + // Do nothing. + } + + @Override + public int read(byte[] buf, int off, int len) throws TTransportException { + checkReadBytesAvailable(len); + int remaining = inputBuffer.remaining(); + if (remaining < len) { + throw new TTransportException(TTransportException.END_OF_FILE, + "There's only " + remaining + "bytes, but it asks for " + len); + } + inputBuffer.get(buf, off, len); + return len; + } + + @Override + public void write(byte[] buf, int off, int len) throws TTransportException { + outputBuffer.write(buf, off, len); + } + + /** + * Get all the bytes written by thrift output protocol. + * + * @return a byte array. + */ + public TByteArrayOutputStream getOutput() { + return outputBuffer; + } +} diff --git a/lib/java/src/main/java/org/apache/thrift/transport/TNonblockingServerSocket.java b/lib/java/src/main/java/org/apache/thrift/transport/TNonblockingServerSocket.java new file mode 100644 index 000000000..535fd6f51 --- /dev/null +++ b/lib/java/src/main/java/org/apache/thrift/transport/TNonblockingServerSocket.java @@ -0,0 +1,186 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + + +package org.apache.thrift.transport; + +import java.io.IOException; +import java.net.InetSocketAddress; +import java.net.ServerSocket; +import java.net.SocketException; +import java.nio.channels.ClosedChannelException; +import java.nio.channels.SelectionKey; +import java.nio.channels.Selector; +import java.nio.channels.ServerSocketChannel; +import java.nio.channels.SocketChannel; + +import org.apache.thrift.TConfiguration; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Wrapper around ServerSocketChannel + */ +public class TNonblockingServerSocket extends TNonblockingServerTransport { + private static final Logger LOGGER = LoggerFactory.getLogger(TNonblockingServerSocket.class.getName()); + + /** + * This channel is where all the nonblocking magic happens. + */ + private ServerSocketChannel serverSocketChannel = null; + + /** + * Underlying ServerSocket object + */ + private ServerSocket serverSocket_ = null; + + /** + * Timeout for client sockets from accept + */ + private int clientTimeout_ = 0; + + /** + * Limit for client sockets request size + */ + private int maxFrameSize_ = 0; + + public static class NonblockingAbstractServerSocketArgs extends + AbstractServerTransportArgs<NonblockingAbstractServerSocketArgs> {} + + /** + * Creates just a port listening server socket + */ + public TNonblockingServerSocket(int port) throws TTransportException { + this(port, 0); + } + + /** + * Creates just a port listening server socket + */ + public TNonblockingServerSocket(int port, int clientTimeout) throws TTransportException { + this(port, clientTimeout, TConfiguration.DEFAULT_MAX_FRAME_SIZE); + } + + public TNonblockingServerSocket(int port, int clientTimeout, int maxFrameSize) throws TTransportException { + this(new NonblockingAbstractServerSocketArgs().port(port).clientTimeout(clientTimeout).maxFrameSize(maxFrameSize)); + } + + public TNonblockingServerSocket(InetSocketAddress bindAddr) throws TTransportException { + this(bindAddr, 0); + } + + public TNonblockingServerSocket(InetSocketAddress bindAddr, int clientTimeout) throws TTransportException { + this(bindAddr, clientTimeout, TConfiguration.DEFAULT_MAX_FRAME_SIZE); + } + + public TNonblockingServerSocket(InetSocketAddress bindAddr, int clientTimeout, int maxFrameSize) throws TTransportException { + this(new NonblockingAbstractServerSocketArgs().bindAddr(bindAddr).clientTimeout(clientTimeout).maxFrameSize(maxFrameSize)); + } + + public TNonblockingServerSocket(NonblockingAbstractServerSocketArgs args) throws TTransportException { + clientTimeout_ = args.clientTimeout; + maxFrameSize_ = args.maxFrameSize; + try { + serverSocketChannel = ServerSocketChannel.open(); + serverSocketChannel.configureBlocking(false); + + // Make server socket + serverSocket_ = serverSocketChannel.socket(); + // Prevent 2MSL delay problem on server restarts + serverSocket_.setReuseAddress(true); + // Bind to listening port + serverSocket_.bind(args.bindAddr, args.backlog); + } catch (IOException ioe) { + serverSocket_ = null; + throw new TTransportException("Could not create ServerSocket on address " + args.bindAddr.toString() + ".", ioe); + } + } + + public void listen() throws TTransportException { + // Make sure not to block on accept + if (serverSocket_ != null) { + try { + serverSocket_.setSoTimeout(0); + } catch (SocketException sx) { + LOGGER.error("Socket exception while setting socket timeout", sx); + } + } + } + + @Override + public TNonblockingSocket accept() throws TTransportException { + if (serverSocket_ == null) { + throw new TTransportException(TTransportException.NOT_OPEN, "No underlying server socket."); + } + try { + SocketChannel socketChannel = serverSocketChannel.accept(); + if (socketChannel == null) { + return null; + } + + TNonblockingSocket tsocket = new TNonblockingSocket(socketChannel); + tsocket.setTimeout(clientTimeout_); + tsocket.setMaxFrameSize(maxFrameSize_); + return tsocket; + } catch (IOException iox) { + throw new TTransportException(iox); + } + } + + public void registerSelector(Selector selector) { + try { + // Register the server socket channel, indicating an interest in + // accepting new connections + serverSocketChannel.register(selector, SelectionKey.OP_ACCEPT); + } catch (ClosedChannelException e) { + // this shouldn't happen, ideally... + // TODO: decide what to do with this. + } + } + + public void close() { + if (serverSocket_ != null) { + try { + serverSocket_.close(); + } catch (IOException iox) { + LOGGER.warn("WARNING: Could not close server socket: " + iox.getMessage()); + } + serverSocket_ = null; + } + } + + public void interrupt() { + // The thread-safeness of this is dubious, but Java documentation suggests + // that it is safe to do this from a different thread context + close(); + } + + public int getPort() { + if (serverSocket_ == null) + return -1; + return serverSocket_.getLocalPort(); + } + + // Expose it for test purpose. + ServerSocketChannel getServerSocketChannel() { + return serverSocketChannel; + } + +} diff --git a/lib/java/src/main/java/org/apache/thrift/transport/TNonblockingServerTransport.java b/lib/java/src/main/java/org/apache/thrift/transport/TNonblockingServerTransport.java new file mode 100644 index 000000000..53d084281 --- /dev/null +++ b/lib/java/src/main/java/org/apache/thrift/transport/TNonblockingServerTransport.java @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + + +package org.apache.thrift.transport; + +import java.nio.channels.Selector; + +/** + * Server transport that can be operated in a nonblocking fashion. + */ +public abstract class TNonblockingServerTransport extends TServerTransport { + + public abstract void registerSelector(Selector selector); + + /** + * + * @return an incoming connection or null if there is none. + * @throws TTransportException on error during this operation. + */ + @Override + public abstract TNonblockingTransport accept() throws TTransportException; +} diff --git a/lib/java/src/main/java/org/apache/thrift/transport/TNonblockingSocket.java b/lib/java/src/main/java/org/apache/thrift/transport/TNonblockingSocket.java new file mode 100644 index 000000000..13c858648 --- /dev/null +++ b/lib/java/src/main/java/org/apache/thrift/transport/TNonblockingSocket.java @@ -0,0 +1,225 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + + +package org.apache.thrift.transport; + +import java.io.IOException; +import java.net.InetSocketAddress; +import java.net.Socket; +import java.net.SocketAddress; +import java.net.SocketException; +import java.nio.ByteBuffer; +import java.nio.channels.SelectionKey; +import java.nio.channels.Selector; +import java.nio.channels.SocketChannel; + +import org.apache.thrift.TConfiguration; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Transport for use with async client. + */ +public class TNonblockingSocket extends TNonblockingTransport { + + private static final Logger LOGGER = LoggerFactory.getLogger(TNonblockingSocket.class.getName()); + + /** + * Host and port if passed in, used for lazy non-blocking connect. + */ + private final SocketAddress socketAddress_; + + private final SocketChannel socketChannel_; + + public TNonblockingSocket(String host, int port) throws IOException, TTransportException { + this(host, port, 0); + } + + /** + * Create a new nonblocking socket transport that will be connected to host:port. + * @param host + * @param port + * @throws IOException + */ + public TNonblockingSocket(String host, int port, int timeout) throws IOException, TTransportException { + this(SocketChannel.open(), timeout, new InetSocketAddress(host, port)); + } + + /** + * Constructor that takes an already created socket. + * + * @param socketChannel Already created SocketChannel object + * @throws IOException if there is an error setting up the streams + */ + public TNonblockingSocket(SocketChannel socketChannel) throws IOException, TTransportException { + this(socketChannel, 0, null); + if (!socketChannel.isConnected()) throw new IOException("Socket must already be connected"); + } + + private TNonblockingSocket(SocketChannel socketChannel, int timeout, SocketAddress socketAddress) + throws IOException, TTransportException { + this(new TConfiguration(), socketChannel, timeout, socketAddress); + } + + private TNonblockingSocket(TConfiguration config, SocketChannel socketChannel, int timeout, SocketAddress socketAddress) + throws IOException, TTransportException { + super(config); + socketChannel_ = socketChannel; + socketAddress_ = socketAddress; + + // make it a nonblocking channel + socketChannel.configureBlocking(false); + + // set options + Socket socket = socketChannel.socket(); + socket.setSoLinger(false, 0); + socket.setTcpNoDelay(true); + socket.setKeepAlive(true); + setTimeout(timeout); + } + + /** + * Register the new SocketChannel with our Selector, indicating + * we'd like to be notified when it's ready for I/O. + * + * @param selector + * @return the selection key for this socket. + */ + public SelectionKey registerSelector(Selector selector, int interests) throws IOException { + return socketChannel_.register(selector, interests); + } + + /** + * Sets the socket timeout, although this implementation never uses blocking operations so it is unused. + * + * @param timeout Milliseconds timeout + */ + public void setTimeout(int timeout) { + try { + socketChannel_.socket().setSoTimeout(timeout); + } catch (SocketException sx) { + LOGGER.warn("Could not set socket timeout.", sx); + } + } + + /** + * Returns a reference to the underlying SocketChannel. + */ + public SocketChannel getSocketChannel() { + return socketChannel_; + } + + /** + * Checks whether the socket is connected. + */ + public boolean isOpen() { + // isConnected() does not return false after close(), but isOpen() does + return socketChannel_.isOpen() && socketChannel_.isConnected(); + } + + /** + * Do not call, the implementation provides its own lazy non-blocking connect. + */ + public void open() throws TTransportException { + throw new RuntimeException("open() is not implemented for TNonblockingSocket"); + } + + /** + * Perform a nonblocking read into buffer. + */ + public int read(ByteBuffer buffer) throws TTransportException { + try { + return socketChannel_.read(buffer); + } catch (IOException iox) { + throw new TTransportException(TTransportException.UNKNOWN, iox); + } + } + + /** + * Reads from the underlying input stream if not null. + */ + public int read(byte[] buf, int off, int len) throws TTransportException { + if ((socketChannel_.validOps() & SelectionKey.OP_READ) != SelectionKey.OP_READ) { + throw new TTransportException(TTransportException.NOT_OPEN, + "Cannot read from write-only socket channel"); + } + try { + return socketChannel_.read(ByteBuffer.wrap(buf, off, len)); + } catch (IOException iox) { + throw new TTransportException(TTransportException.UNKNOWN, iox); + } + } + + /** + * Perform a nonblocking write of the data in buffer; + */ + public int write(ByteBuffer buffer) throws TTransportException { + try { + return socketChannel_.write(buffer); + } catch (IOException iox) { + throw new TTransportException(TTransportException.UNKNOWN, iox); + } + } + + /** + * Writes to the underlying output stream if not null. + */ + public void write(byte[] buf, int off, int len) throws TTransportException { + if ((socketChannel_.validOps() & SelectionKey.OP_WRITE) != SelectionKey.OP_WRITE) { + throw new TTransportException(TTransportException.NOT_OPEN, + "Cannot write to write-only socket channel"); + } + write(ByteBuffer.wrap(buf, off, len)); + } + + /** + * Noop. + */ + public void flush() throws TTransportException { + // Not supported by SocketChannel. + } + + /** + * Closes the socket. + */ + public void close() { + try { + socketChannel_.close(); + } catch (IOException iox) { + LOGGER.warn("Could not close socket.", iox); + } + } + + /** {@inheritDoc} */ + public boolean startConnect() throws IOException { + return socketChannel_.connect(socketAddress_); + } + + /** {@inheritDoc} */ + public boolean finishConnect() throws IOException { + return socketChannel_.finishConnect(); + } + + @Override + public String toString() { + return "[remote: " + socketChannel_.socket().getRemoteSocketAddress() + + ", local: " + socketChannel_.socket().getLocalAddress() + "]" ; + } +} diff --git a/lib/java/src/main/java/org/apache/thrift/transport/TNonblockingTransport.java b/lib/java/src/main/java/org/apache/thrift/transport/TNonblockingTransport.java new file mode 100644 index 000000000..30ec9d25c --- /dev/null +++ b/lib/java/src/main/java/org/apache/thrift/transport/TNonblockingTransport.java @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.transport; + +import org.apache.thrift.TConfiguration; + +import java.io.IOException; +import java.net.SocketAddress; +import java.nio.channels.SelectionKey; +import java.nio.channels.Selector; + +public abstract class TNonblockingTransport extends TEndpointTransport { + + public TNonblockingTransport(TConfiguration config) throws TTransportException { + super(config); + } + + /** + * Non-blocking connection initialization. + * @see java.nio.channels.SocketChannel#connect(SocketAddress remote) + */ + public abstract boolean startConnect() throws IOException; + + /** + * Non-blocking connection completion. + * @see java.nio.channels.SocketChannel#finishConnect() + */ + public abstract boolean finishConnect() throws IOException; + + public abstract SelectionKey registerSelector(Selector selector, int interests) throws IOException; + +} diff --git a/lib/java/src/main/java/org/apache/thrift/transport/TSSLTransportFactory.java b/lib/java/src/main/java/org/apache/thrift/transport/TSSLTransportFactory.java new file mode 100644 index 000000000..3389e4d2a --- /dev/null +++ b/lib/java/src/main/java/org/apache/thrift/transport/TSSLTransportFactory.java @@ -0,0 +1,449 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.transport; + +import java.io.FileInputStream; +import java.io.FileNotFoundException; +import java.io.InputStream; +import java.io.IOException; +import java.net.InetAddress; +import java.net.URL; +import java.net.MalformedURLException; +import java.security.KeyStore; +import java.util.Arrays; + +import javax.net.ssl.KeyManagerFactory; +import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLServerSocket; +import javax.net.ssl.SSLServerSocketFactory; +import javax.net.ssl.SSLSocket; +import javax.net.ssl.SSLSocketFactory; +import javax.net.ssl.TrustManagerFactory; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * A Factory for providing and setting up Client and Server SSL wrapped + * TSocket and TServerSocket + */ +public class TSSLTransportFactory { + + private static final Logger LOGGER = + LoggerFactory.getLogger(TSSLTransportFactory.class); + + /** + * Get a SSL wrapped TServerSocket bound to the specified port. In this + * configuration the default settings are used. Default settings are retrieved + * from System properties that are set. + * + * Example system properties: + * -Djavax.net.ssl.trustStore=<truststore location> + * -Djavax.net.ssl.trustStorePassword=password + * -Djavax.net.ssl.keyStore=<keystore location> + * -Djavax.net.ssl.keyStorePassword=password + * + * @param port + * @return A SSL wrapped TServerSocket + * @throws TTransportException + */ + public static TServerSocket getServerSocket(int port) throws TTransportException { + return getServerSocket(port, 0); + } + + /** + * Get a default SSL wrapped TServerSocket bound to the specified port + * + * @param port + * @param clientTimeout + * @return A SSL wrapped TServerSocket + * @throws TTransportException + */ + public static TServerSocket getServerSocket(int port, int clientTimeout) throws TTransportException { + return getServerSocket(port, clientTimeout, false, null); + } + + /** + * Get a default SSL wrapped TServerSocket bound to the specified port and interface + * + * @param port + * @param clientTimeout + * @param ifAddress + * @return A SSL wrapped TServerSocket + * @throws TTransportException + */ + public static TServerSocket getServerSocket(int port, int clientTimeout, boolean clientAuth, InetAddress ifAddress) throws TTransportException { + SSLServerSocketFactory factory = (SSLServerSocketFactory) SSLServerSocketFactory.getDefault(); + return createServer(factory, port, clientTimeout, clientAuth, ifAddress, null); + } + + /** + * Get a configured SSL wrapped TServerSocket bound to the specified port and interface. + * Here the TSSLTransportParameters are used to set the values for the algorithms, keystore, + * truststore and other settings + * + * @param port + * @param clientTimeout + * @param ifAddress + * @param params + * @return A SSL wrapped TServerSocket + * @throws TTransportException + */ + public static TServerSocket getServerSocket(int port, int clientTimeout, InetAddress ifAddress, TSSLTransportParameters params) throws TTransportException { + if (params == null || !(params.isKeyStoreSet || params.isTrustStoreSet)) { + throw new TTransportException("Either one of the KeyStore or TrustStore must be set for SSLTransportParameters"); + } + + SSLContext ctx = createSSLContext(params); + return createServer(ctx.getServerSocketFactory(), port, clientTimeout, params.clientAuth, ifAddress, params); + } + + private static TServerSocket createServer(SSLServerSocketFactory factory, int port, int timeout, boolean clientAuth, + InetAddress ifAddress, TSSLTransportParameters params) throws TTransportException { + try { + SSLServerSocket serverSocket = (SSLServerSocket) factory.createServerSocket(port, 100, ifAddress); + serverSocket.setSoTimeout(timeout); + serverSocket.setNeedClientAuth(clientAuth); + if (params != null && params.cipherSuites != null) { + serverSocket.setEnabledCipherSuites(params.cipherSuites); + } + return new TServerSocket(new TServerSocket.ServerSocketTransportArgs(). + serverSocket(serverSocket).clientTimeout(timeout)); + } catch (Exception e) { + throw new TTransportException("Could not bind to port " + port, e); + } + } + + /** + * Get a default SSL wrapped TSocket connected to the specified host and port. All + * the client methods return a bound connection. So there is no need to call open() on the + * TTransport. + * + * @param host + * @param port + * @param timeout + * @return A SSL wrapped TSocket + * @throws TTransportException + */ + public static TSocket getClientSocket(String host, int port, int timeout) throws TTransportException { + SSLSocketFactory factory = (SSLSocketFactory) SSLSocketFactory.getDefault(); + return createClient(factory, host, port, timeout); + } + + /** + * Get a default SSL wrapped TSocket connected to the specified host and port. + * + * @param host + * @param port + * @return A SSL wrapped TSocket + * @throws TTransportException + */ + public static TSocket getClientSocket(String host, int port) throws TTransportException { + return getClientSocket(host, port, 0); + } + + /** + * Get a custom configured SSL wrapped TSocket. The SSL settings are obtained from the + * passed in TSSLTransportParameters. + * + * @param host + * @param port + * @param timeout + * @param params + * @return A SSL wrapped TSocket + * @throws TTransportException + */ + public static TSocket getClientSocket(String host, int port, int timeout, TSSLTransportParameters params) throws TTransportException { + if (params == null || !(params.isKeyStoreSet || params.isTrustStoreSet)) { + throw new TTransportException(TTransportException.NOT_OPEN, "Either one of the KeyStore or TrustStore must be set for SSLTransportParameters"); + } + + SSLContext ctx = createSSLContext(params); + return createClient(ctx.getSocketFactory(), host, port, timeout); + } + + private static SSLContext createSSLContext(TSSLTransportParameters params) throws TTransportException { + SSLContext ctx; + InputStream in = null; + InputStream is = null; + + try { + ctx = SSLContext.getInstance(params.protocol); + TrustManagerFactory tmf = null; + KeyManagerFactory kmf = null; + + if (params.isTrustStoreSet) { + tmf = TrustManagerFactory.getInstance(params.trustManagerType); + KeyStore ts = KeyStore.getInstance(params.trustStoreType); + if (params.trustStoreStream != null) { + in = params.trustStoreStream; + } else { + in = getStoreAsStream(params.trustStore); + } + ts.load(in, + (params.trustPass != null ? params.trustPass.toCharArray() : null)); + tmf.init(ts); + } + + if (params.isKeyStoreSet) { + kmf = KeyManagerFactory.getInstance(params.keyManagerType); + KeyStore ks = KeyStore.getInstance(params.keyStoreType); + if (params.keyStoreStream != null) { + is = params.keyStoreStream; + } else { + is = getStoreAsStream(params.keyStore); + } + ks.load(is, params.keyPass.toCharArray()); + kmf.init(ks, params.keyPass.toCharArray()); + } + + if (params.isKeyStoreSet && params.isTrustStoreSet) { + ctx.init(kmf.getKeyManagers(), tmf.getTrustManagers(), null); + } + else if (params.isKeyStoreSet) { + ctx.init(kmf.getKeyManagers(), null, null); + } + else { + ctx.init(null, tmf.getTrustManagers(), null); + } + + } catch (Exception e) { + throw new TTransportException(TTransportException.NOT_OPEN, "Error creating the transport", e); + } finally { + if (in != null) { + try { + in.close(); + } catch (IOException e) { + LOGGER.warn("Unable to close stream", e); + } + } + if (is != null) { + try { + is.close(); + } catch (IOException e) { + LOGGER.warn("Unable to close stream", e); + } + } + } + + return ctx; + } + + private static InputStream getStoreAsStream(String store) throws IOException { + try { + return new FileInputStream(store); + } catch(FileNotFoundException e) { + } + + InputStream storeStream = null; + try { + storeStream = new URL(store).openStream(); + if (storeStream != null) { + return storeStream; + } + } catch(MalformedURLException e) { + } + + storeStream = Thread.currentThread().getContextClassLoader().getResourceAsStream(store); + + if (storeStream != null) { + return storeStream; + } else { + throw new IOException("Could not load file: " + store); + } + } + + private static TSocket createClient(SSLSocketFactory factory, String host, int port, int timeout) throws TTransportException { + try { + SSLSocket socket = (SSLSocket) factory.createSocket(host, port); + socket.setSoTimeout(timeout); + return new TSocket(socket); + } catch (TTransportException tte) { + throw tte; + } catch (Exception e) { + throw new TTransportException(TTransportException.NOT_OPEN, "Could not connect to " + host + " on port " + port, e); + } + } + + + /** + * A Class to hold all the SSL parameters + */ + public static class TSSLTransportParameters { + protected String protocol = "TLS"; + protected String keyStore; + protected InputStream keyStoreStream; + protected String keyPass; + protected String keyManagerType = KeyManagerFactory.getDefaultAlgorithm(); + protected String keyStoreType = "JKS"; + protected String trustStore; + protected InputStream trustStoreStream; + protected String trustPass; + protected String trustManagerType = TrustManagerFactory.getDefaultAlgorithm(); + protected String trustStoreType = "JKS"; + protected String[] cipherSuites; + protected boolean clientAuth = false; + protected boolean isKeyStoreSet = false; + protected boolean isTrustStoreSet = false; + + public TSSLTransportParameters() {} + + /** + * Create parameters specifying the protocol and cipher suites + * + * @param protocol The specific protocol (TLS/SSL) can be specified with versions + * @param cipherSuites + */ + public TSSLTransportParameters(String protocol, String[] cipherSuites) { + this(protocol, cipherSuites, false); + } + + /** + * Create parameters specifying the protocol, cipher suites and if client authentication + * is required + * + * @param protocol The specific protocol (TLS/SSL) can be specified with versions + * @param cipherSuites + * @param clientAuth + */ + public TSSLTransportParameters(String protocol, String[] cipherSuites, boolean clientAuth) { + if (protocol != null) { + this.protocol = protocol; + } + this.cipherSuites = cipherSuites != null ? Arrays.copyOf(cipherSuites, cipherSuites.length) : null; + this.clientAuth = clientAuth; + } + + /** + * Set the keystore, password, certificate type and the store type + * + * @param keyStore Location of the Keystore on disk + * @param keyPass Keystore password + * @param keyManagerType The default is X509 + * @param keyStoreType The default is JKS + */ + public void setKeyStore(String keyStore, String keyPass, String keyManagerType, String keyStoreType) { + this.keyStore = keyStore; + this.keyPass = keyPass; + if (keyManagerType != null) { + this.keyManagerType = keyManagerType; + } + if (keyStoreType != null) { + this.keyStoreType = keyStoreType; + } + isKeyStoreSet = true; + } + + /** + * Set the keystore, password, certificate type and the store type + * + * @param keyStoreStream Keystore content input stream + * @param keyPass Keystore password + * @param keyManagerType The default is X509 + * @param keyStoreType The default is JKS + */ + public void setKeyStore(InputStream keyStoreStream, String keyPass, String keyManagerType, String keyStoreType) { + this.keyStoreStream = keyStoreStream; + setKeyStore("", keyPass, keyManagerType, keyStoreType); + } + + /** + * Set the keystore and password + * + * @param keyStore Location of the Keystore on disk + * @param keyPass Keystore password + */ + public void setKeyStore(String keyStore, String keyPass) { + setKeyStore(keyStore, keyPass, null, null); + } + + /** + * Set the keystore and password + * + * @param keyStoreStream Keystore content input stream + * @param keyPass Keystore password + */ + public void setKeyStore(InputStream keyStoreStream, String keyPass) { + setKeyStore(keyStoreStream, keyPass, null, null); + } + + /** + * Set the truststore, password, certificate type and the store type + * + * @param trustStore Location of the Truststore on disk + * @param trustPass Truststore password + * @param trustManagerType The default is X509 + * @param trustStoreType The default is JKS + */ + public void setTrustStore(String trustStore, String trustPass, String trustManagerType, String trustStoreType) { + this.trustStore = trustStore; + this.trustPass = trustPass; + if (trustManagerType != null) { + this.trustManagerType = trustManagerType; + } + if (trustStoreType != null) { + this.trustStoreType = trustStoreType; + } + isTrustStoreSet = true; + } + + /** + * Set the truststore, password, certificate type and the store type + * + * @param trustStoreStream Truststore content input stream + * @param trustPass Truststore password + * @param trustManagerType The default is X509 + * @param trustStoreType The default is JKS + */ + public void setTrustStore(InputStream trustStoreStream, String trustPass, String trustManagerType, String trustStoreType) { + this.trustStoreStream = trustStoreStream; + setTrustStore("", trustPass, trustManagerType, trustStoreType); + } + + /** + * Set the truststore and password + * + * @param trustStore Location of the Truststore on disk + * @param trustPass Truststore password + */ + public void setTrustStore(String trustStore, String trustPass) { + setTrustStore(trustStore, trustPass, null, null); + } + + /** + * Set the truststore and password + * + * @param trustStoreStream Truststore content input stream + * @param trustPass Truststore password + */ + public void setTrustStore(InputStream trustStoreStream, String trustPass) { + setTrustStore(trustStoreStream, trustPass, null, null); + } + + /** + * Set if client authentication is required + * + * @param clientAuth + */ + public void requireClientAuth(boolean clientAuth) { + this.clientAuth = clientAuth; + } + } +} diff --git a/lib/java/src/main/java/org/apache/thrift/transport/TSaslClientTransport.java b/lib/java/src/main/java/org/apache/thrift/transport/TSaslClientTransport.java new file mode 100644 index 000000000..e5ca41831 --- /dev/null +++ b/lib/java/src/main/java/org/apache/thrift/transport/TSaslClientTransport.java @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.transport; + +import java.nio.charset.StandardCharsets; +import java.util.Map; + +import javax.security.auth.callback.CallbackHandler; +import javax.security.sasl.Sasl; +import javax.security.sasl.SaslClient; +import javax.security.sasl.SaslException; + +import org.apache.thrift.transport.sasl.NegotiationStatus; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Wraps another Thrift <code>TTransport</code>, but performs SASL client + * negotiation on the call to <code>open()</code>. This class will wrap ensuing + * communication over it, if a SASL QOP is negotiated with the other party. + */ +public class TSaslClientTransport extends TSaslTransport { + + private static final Logger LOGGER = LoggerFactory.getLogger(TSaslClientTransport.class); + + /** + * The name of the mechanism this client supports. + */ + private final String mechanism; + + /** + * Uses the given <code>SaslClient</code>. + * + * @param saslClient + * The <code>SaslClient</code> to use for the subsequent SASL + * negotiation. + * @param transport + * Transport underlying this one. + */ + public TSaslClientTransport(SaslClient saslClient, TTransport transport) throws TTransportException { + super(saslClient, transport); + mechanism = saslClient.getMechanismName(); + } + + /** + * Creates a <code>SaslClient</code> using the given SASL-specific parameters. + * See the Java documentation for <code>Sasl.createSaslClient</code> for the + * details of the parameters. + * + * @param transport + * The underlying Thrift transport. + * @throws SaslException + */ + public TSaslClientTransport(String mechanism, String authorizationId, String protocol, + String serverName, Map<String, String> props, CallbackHandler cbh, TTransport transport) + throws SaslException, TTransportException { + super(Sasl.createSaslClient(new String[] { mechanism }, authorizationId, protocol, serverName, + props, cbh), transport); + this.mechanism = mechanism; + } + + + @Override + protected SaslRole getRole() { + return SaslRole.CLIENT; + } + + /** + * Performs the client side of the initial portion of the Thrift SASL + * protocol. Generates and sends the initial response to the server, including + * which mechanism this client wants to use. + */ + @Override + protected void handleSaslStartMessage() throws TTransportException, SaslException { + SaslClient saslClient = getSaslClient(); + + byte[] initialResponse = new byte[0]; + if (saslClient.hasInitialResponse()) + initialResponse = saslClient.evaluateChallenge(initialResponse); + + LOGGER.debug("Sending mechanism name {} and initial response of length {}", mechanism, + initialResponse.length); + + byte[] mechanismBytes = mechanism.getBytes(StandardCharsets.UTF_8); + sendSaslMessage(NegotiationStatus.START, + mechanismBytes); + // Send initial response + sendSaslMessage(saslClient.isComplete() ? NegotiationStatus.COMPLETE : NegotiationStatus.OK, + initialResponse); + underlyingTransport.flush(); + } +} diff --git a/lib/java/src/main/java/org/apache/thrift/transport/TSaslServerTransport.java b/lib/java/src/main/java/org/apache/thrift/transport/TSaslServerTransport.java new file mode 100644 index 000000000..9111712a4 --- /dev/null +++ b/lib/java/src/main/java/org/apache/thrift/transport/TSaslServerTransport.java @@ -0,0 +1,212 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.transport; + +import java.lang.ref.WeakReference; +import java.nio.charset.StandardCharsets; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.WeakHashMap; + +import javax.security.auth.callback.CallbackHandler; +import javax.security.sasl.Sasl; +import javax.security.sasl.SaslException; +import javax.security.sasl.SaslServer; + +import org.apache.thrift.transport.sasl.NegotiationStatus; +import org.apache.thrift.transport.sasl.TSaslServerDefinition; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Wraps another Thrift <code>TTransport</code>, but performs SASL server + * negotiation on the call to <code>open()</code>. This class will wrap ensuing + * communication over it, if a SASL QOP is negotiated with the other party. + */ +public class TSaslServerTransport extends TSaslTransport { + + private static final Logger LOGGER = LoggerFactory.getLogger(TSaslServerTransport.class); + + /** + * Mapping from SASL mechanism name -> all the parameters required to + * instantiate a SASL server. + */ + private Map<String, TSaslServerDefinition> serverDefinitionMap = new HashMap<String, TSaslServerDefinition>(); + + /** + * Uses the given underlying transport. Assumes that addServerDefinition is + * called later. + * + * @param transport + * Transport underlying this one. + */ + public TSaslServerTransport(TTransport transport) throws TTransportException { + super(transport); + } + + /** + * Creates a <code>SaslServer</code> using the given SASL-specific parameters. + * See the Java documentation for <code>Sasl.createSaslServer</code> for the + * details of the parameters. + * + * @param transport + * The underlying Thrift transport. + */ + public TSaslServerTransport(String mechanism, String protocol, String serverName, + Map<String, String> props, CallbackHandler cbh, TTransport transport) throws TTransportException { + super(transport); + addServerDefinition(mechanism, protocol, serverName, props, cbh); + } + + private TSaslServerTransport(Map<String, TSaslServerDefinition> serverDefinitionMap, TTransport transport) throws TTransportException { + super(transport); + this.serverDefinitionMap.putAll(serverDefinitionMap); + } + + /** + * Add a supported server definition to this transport. See the Java + * documentation for <code>Sasl.createSaslServer</code> for the details of the + * parameters. + */ + public void addServerDefinition(String mechanism, String protocol, String serverName, + Map<String, String> props, CallbackHandler cbh) { + serverDefinitionMap.put(mechanism, new TSaslServerDefinition(mechanism, protocol, serverName, + props, cbh)); + } + + @Override + protected SaslRole getRole() { + return SaslRole.SERVER; + } + + /** + * Performs the server side of the initial portion of the Thrift SASL protocol. + * Receives the initial response from the client, creates a SASL server using + * the mechanism requested by the client (if this server supports it), and + * sends the first challenge back to the client. + */ + @Override + protected void handleSaslStartMessage() throws TTransportException, SaslException { + SaslResponse message = receiveSaslMessage(); + + LOGGER.debug("Received start message with status {}", message.status); + if (message.status != NegotiationStatus.START) { + throw sendAndThrowMessage(NegotiationStatus.ERROR, "Expecting START status, received " + message.status); + } + + // Get the mechanism name. + String mechanismName = new String(message.payload, StandardCharsets.UTF_8); + TSaslServerDefinition serverDefinition = serverDefinitionMap.get(mechanismName); + LOGGER.debug("Received mechanism name '{}'", mechanismName); + + if (serverDefinition == null) { + throw sendAndThrowMessage(NegotiationStatus.BAD, "Unsupported mechanism type " + mechanismName); + } + SaslServer saslServer = Sasl.createSaslServer(serverDefinition.mechanism, + serverDefinition.protocol, serverDefinition.serverName, serverDefinition.props, + serverDefinition.cbh); + setSaslServer(saslServer); + } + + /** + * <code>TTransportFactory</code> to create + * <code>TSaslServerTransports</code>. Ensures that a given + * underlying <code>TTransport</code> instance receives the same + * <code>TSaslServerTransport</code>. This is kind of an awful hack to work + * around the fact that Thrift is designed assuming that + * <code>TTransport</code> instances are stateless, and thus the existing + * <code>TServers</code> use different <code>TTransport</code> instances for + * input and output. + */ + public static class Factory extends TTransportFactory { + + /** + * This is the implementation of the awful hack described above. + * <code>WeakHashMap</code> is used to ensure that we don't leak memory. + */ + private static Map<TTransport, WeakReference<TSaslServerTransport>> transportMap = + Collections.synchronizedMap(new WeakHashMap<TTransport, WeakReference<TSaslServerTransport>>()); + + /** + * Mapping from SASL mechanism name -> all the parameters required to + * instantiate a SASL server. + */ + private Map<String, TSaslServerDefinition> serverDefinitionMap = new HashMap<String, TSaslServerDefinition>(); + + /** + * Create a new Factory. Assumes that <code>addServerDefinition</code> will + * be called later. + */ + public Factory() { + super(); + } + + /** + * Create a new <code>Factory</code>, initially with the single server + * definition given. You may still call <code>addServerDefinition</code> + * later. See the Java documentation for <code>Sasl.createSaslServer</code> + * for the details of the parameters. + */ + public Factory(String mechanism, String protocol, String serverName, + Map<String, String> props, CallbackHandler cbh) { + super(); + addServerDefinition(mechanism, protocol, serverName, props, cbh); + } + + /** + * Add a supported server definition to the transports created by this + * factory. See the Java documentation for + * <code>Sasl.createSaslServer</code> for the details of the parameters. + */ + public void addServerDefinition(String mechanism, String protocol, String serverName, + Map<String, String> props, CallbackHandler cbh) { + serverDefinitionMap.put(mechanism, new TSaslServerDefinition(mechanism, protocol, serverName, + props, cbh)); + } + + /** + * Get a new <code>TSaslServerTransport</code> instance, or reuse the + * existing one if a <code>TSaslServerTransport</code> has already been + * created before using the given <code>TTransport</code> as an underlying + * transport. This ensures that a given underlying transport instance + * receives the same <code>TSaslServerTransport</code>. + */ + @Override + public TTransport getTransport(TTransport base) throws TTransportException { + WeakReference<TSaslServerTransport> ret = transportMap.get(base); + if (ret == null || ret.get() == null) { + LOGGER.debug("transport map does not contain key", base); + ret = new WeakReference<TSaslServerTransport>(new TSaslServerTransport(serverDefinitionMap, base)); + try { + ret.get().open(); + } catch (TTransportException e) { + LOGGER.debug("failed to open server transport", e); + throw new RuntimeException(e); + } + transportMap.put(base, ret); // No need for putIfAbsent(). + // Concurrent calls to getTransport() will pass in different TTransports. + } else { + LOGGER.debug("transport map does contain key {}", base); + } + return ret.get(); + } + } +} diff --git a/lib/java/src/main/java/org/apache/thrift/transport/TSaslTransport.java b/lib/java/src/main/java/org/apache/thrift/transport/TSaslTransport.java new file mode 100644 index 000000000..b22469d2b --- /dev/null +++ b/lib/java/src/main/java/org/apache/thrift/transport/TSaslTransport.java @@ -0,0 +1,546 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.transport; + +import java.nio.charset.StandardCharsets; +import java.util.Objects; + +import javax.security.sasl.Sasl; +import javax.security.sasl.SaslClient; +import javax.security.sasl.SaslException; +import javax.security.sasl.SaslServer; + +import org.apache.thrift.EncodingUtils; +import org.apache.thrift.TByteArrayOutputStream; +import org.apache.thrift.TConfiguration; +import org.apache.thrift.transport.layered.TFramedTransport; +import org.apache.thrift.transport.sasl.NegotiationStatus; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * A superclass for SASL client/server thrift transports. A subclass need only + * implement the <code>open</code> method. + */ +abstract class TSaslTransport extends TEndpointTransport { + + private static final Logger LOGGER = LoggerFactory.getLogger(TSaslTransport.class); + + protected static final int DEFAULT_MAX_LENGTH = 0x7FFFFFFF; + + protected static final int MECHANISM_NAME_BYTES = 1; + protected static final int STATUS_BYTES = 1; + protected static final int PAYLOAD_LENGTH_BYTES = 4; + + protected static enum SaslRole { + SERVER, CLIENT; + } + + /** + * Transport underlying this one. + */ + protected TTransport underlyingTransport; + + /** + * Either a SASL client or a SASL server. + */ + private SaslParticipant sasl; + + /** + * Whether or not we should wrap/unwrap reads/writes. Determined by whether or + * not a QOP is negotiated during the SASL handshake. + */ + private boolean shouldWrap = false; + + /** + * Buffer for input. + */ + private TMemoryInputTransport readBuffer; + + /** + * Buffer for output. + */ + private final TByteArrayOutputStream writeBuffer = new TByteArrayOutputStream(1024); + + /** + * Create a TSaslTransport. It's assumed that setSaslServer will be called + * later to initialize the SASL endpoint underlying this transport. + * + * @param underlyingTransport + * The thrift transport which this transport is wrapping. + */ + protected TSaslTransport(TTransport underlyingTransport) throws TTransportException { + super(Objects.isNull(underlyingTransport.getConfiguration()) ? new TConfiguration() : underlyingTransport.getConfiguration()); + this.underlyingTransport = underlyingTransport; + this.readBuffer = new TMemoryInputTransport(underlyingTransport.getConfiguration()); + } + + /** + * Create a TSaslTransport which acts as a client. + * + * @param saslClient + * The <code>SaslClient</code> which this transport will use for SASL + * negotiation. + * @param underlyingTransport + * The thrift transport which this transport is wrapping. + */ + protected TSaslTransport(SaslClient saslClient, TTransport underlyingTransport) throws TTransportException { + super(Objects.isNull(underlyingTransport.getConfiguration()) ? new TConfiguration() : underlyingTransport.getConfiguration()); + sasl = new SaslParticipant(saslClient); + this.underlyingTransport = underlyingTransport; + this.readBuffer = new TMemoryInputTransport(underlyingTransport.getConfiguration()); + } + + protected void setSaslServer(SaslServer saslServer) { + sasl = new SaslParticipant(saslServer); + } + + // Used to read the status byte and payload length. + private final byte[] messageHeader = new byte[STATUS_BYTES + PAYLOAD_LENGTH_BYTES]; + + /** + * Send a complete Thrift SASL message. + * + * @param status + * The status to send. + * @param payload + * The data to send as the payload of this message. + * @throws TTransportException + */ + protected void sendSaslMessage(NegotiationStatus status, byte[] payload) throws TTransportException { + if (payload == null) + payload = new byte[0]; + + messageHeader[0] = status.getValue(); + EncodingUtils.encodeBigEndian(payload.length, messageHeader, STATUS_BYTES); + + LOGGER.debug("{}: Writing message with status {} and payload length {}", + getRole(), status, payload.length); + + underlyingTransport.write(messageHeader); + underlyingTransport.write(payload); + underlyingTransport.flush(); + } + + /** + * Read a complete Thrift SASL message. + * + * @return The SASL status and payload from this message. + * @throws TTransportException + * Thrown if there is a failure reading from the underlying + * transport, or if a status code of BAD or ERROR is encountered. + */ + protected SaslResponse receiveSaslMessage() throws TTransportException { + underlyingTransport.readAll(messageHeader, 0, messageHeader.length); + + byte statusByte = messageHeader[0]; + + NegotiationStatus status = NegotiationStatus.byValue(statusByte); + if (status == null) { + throw sendAndThrowMessage(NegotiationStatus.ERROR, "Invalid status " + statusByte); + } + + int payloadBytes = EncodingUtils.decodeBigEndian(messageHeader, STATUS_BYTES); + if (payloadBytes < 0 || payloadBytes > getConfiguration().getMaxMessageSize() /* 100 MB */) { + throw sendAndThrowMessage( + NegotiationStatus.ERROR, "Invalid payload header length: " + payloadBytes); + } + + byte[] payload = new byte[payloadBytes]; + underlyingTransport.readAll(payload, 0, payload.length); + + if (status == NegotiationStatus.BAD || status == NegotiationStatus.ERROR) { + String remoteMessage = new String(payload, StandardCharsets.UTF_8); + throw new TTransportException("Peer indicated failure: " + remoteMessage); + } + LOGGER.debug("{}: Received message with status {} and payload length {}", + getRole(), status, payload.length); + return new SaslResponse(status, payload); + } + + /** + * Send a Thrift SASL message with the given status (usually BAD or ERROR) and + * string message, and then throw a TTransportException with the given + * message. + * + * @param status + * The Thrift SASL status code to send. Usually BAD or ERROR. + * @param message + * The optional message to send to the other side. + * @throws TTransportException + * Always thrown with the message provided. + * @return always throws TTransportException but declares return type to allow + * throw sendAndThrowMessage(...) to inform compiler control flow + */ + protected TTransportException sendAndThrowMessage(NegotiationStatus status, String message) throws TTransportException { + try { + sendSaslMessage(status, message.getBytes(StandardCharsets.UTF_8)); + } catch (Exception e) { + LOGGER.warn("Could not send failure response", e); + message += "\nAlso, could not send response: " + e.toString(); + } + throw new TTransportException(message); + } + + /** + * Implemented by subclasses to start the Thrift SASL handshake process. When + * this method completes, the <code>SaslParticipant</code> in this class is + * assumed to be initialized. + * + * @throws TTransportException + * @throws SaslException + */ + abstract protected void handleSaslStartMessage() throws TTransportException, SaslException; + + protected abstract SaslRole getRole(); + + /** + * Opens the underlying transport if it's not already open and then performs + * SASL negotiation. If a QOP is negotiated during this SASL handshake, it used + * for all communication on this transport after this call is complete. + */ + @Override + public void open() throws TTransportException { + /* + * readSaslHeader is used to tag whether the SASL header has been read properly. + * If there is a problem in reading the header, there might not be any + * data in the stream, possibly a TCP health check from load balancer. + */ + boolean readSaslHeader = false; + + LOGGER.debug("opening transport {}", this); + if (sasl != null && sasl.isComplete()) + throw new TTransportException("SASL transport already open"); + + if (!underlyingTransport.isOpen()) + underlyingTransport.open(); + + try { + // Negotiate a SASL mechanism. The client also sends its + // initial response, or an empty one. + handleSaslStartMessage(); + readSaslHeader = true; + LOGGER.debug("{}: Start message handled", getRole()); + + SaslResponse message = null; + while (!sasl.isComplete()) { + message = receiveSaslMessage(); + if (message.status != NegotiationStatus.COMPLETE && + message.status != NegotiationStatus.OK) { + throw new TTransportException("Expected COMPLETE or OK, got " + message.status); + } + + byte[] challenge = sasl.evaluateChallengeOrResponse(message.payload); + + // If we are the client, and the server indicates COMPLETE, we don't need to + // send back any further response. + if (message.status == NegotiationStatus.COMPLETE && + getRole() == SaslRole.CLIENT) { + LOGGER.debug("{}: All done!", getRole()); + continue; + } + + sendSaslMessage(sasl.isComplete() ? NegotiationStatus.COMPLETE : NegotiationStatus.OK, + challenge); + } + LOGGER.debug("{}: Main negotiation loop complete", getRole()); + + // If we're the client, and we're complete, but the server isn't + // complete yet, we need to wait for its response. This will occur + // with ANONYMOUS auth, for example, where we send an initial response + // and are immediately complete. + if (getRole() == SaslRole.CLIENT && + (message == null || message.status == NegotiationStatus.OK)) { + LOGGER.debug("{}: SASL Client receiving last message", getRole()); + message = receiveSaslMessage(); + if (message.status != NegotiationStatus.COMPLETE) { + throw new TTransportException( + "Expected SASL COMPLETE, but got " + message.status); + } + } + } catch (SaslException e) { + try { + LOGGER.error("SASL negotiation failure", e); + throw sendAndThrowMessage(NegotiationStatus.BAD, e.getMessage()); + } finally { + underlyingTransport.close(); + } + } catch (TTransportException e) { + // If there is no-data or no-sasl header in the stream, + // log the failure, and clean up the underlying transport. + if (!readSaslHeader && e.getType() == TTransportException.END_OF_FILE) { + underlyingTransport.close(); + LOGGER.debug("No data or no sasl data in the stream during negotiation"); + } + throw e; + } + + String qop = (String) sasl.getNegotiatedProperty(Sasl.QOP); + if (qop != null && !qop.equalsIgnoreCase("auth")) + shouldWrap = true; + } + + /** + * Get the underlying <code>SaslClient</code>. + * + * @return The <code>SaslClient</code>, or <code>null</code> if this transport + * is backed by a <code>SaslServer</code>. + */ + public SaslClient getSaslClient() { + return sasl.saslClient; + } + + /** + * Get the underlying transport that Sasl is using. + * @return The <code>TTransport</code> transport + */ + public TTransport getUnderlyingTransport() { + return underlyingTransport; + } + + /** + * Get the underlying <code>SaslServer</code>. + * + * @return The <code>SaslServer</code>, or <code>null</code> if this transport + * is backed by a <code>SaslClient</code>. + */ + public SaslServer getSaslServer() { + return sasl.saslServer; + } + + /** + * Read a 4-byte word from the underlying transport and interpret it as an + * integer. + * + * @return The length prefix of the next SASL message to read. + * @throws TTransportException + * Thrown if reading from the underlying transport fails. + */ + protected int readLength() throws TTransportException { + byte[] lenBuf = new byte[4]; + underlyingTransport.readAll(lenBuf, 0, lenBuf.length); + return EncodingUtils.decodeBigEndian(lenBuf); + } + + /** + * Write the given integer as 4 bytes to the underlying transport. + * + * @param length + * The length prefix of the next SASL message to write. + * @throws TTransportException + * Thrown if writing to the underlying transport fails. + */ + protected void writeLength(int length) throws TTransportException { + byte[] lenBuf = new byte[4]; + TFramedTransport.encodeFrameSize(length, lenBuf); + underlyingTransport.write(lenBuf); + } + + // Below is the SASL implementation of the TTransport interface. + + /** + * Closes the underlying transport and disposes of the SASL implementation + * underlying this transport. + */ + @Override + public void close() { + underlyingTransport.close(); + try { + sasl.dispose(); + } catch (SaslException e) { + LOGGER.warn("Failed to dispose sasl participant.", e); + } + } + + /** + * True if the underlying transport is open and the SASL handshake is + * complete. + */ + @Override + public boolean isOpen() { + return underlyingTransport.isOpen() && sasl != null && sasl.isComplete(); + } + + /** + * Read from the underlying transport. Unwraps the contents if a QOP was + * negotiated during the SASL handshake. + */ + @Override + public int read(byte[] buf, int off, int len) throws TTransportException { + if (!isOpen()) + throw new TTransportException("SASL authentication not complete"); + + int got = readBuffer.read(buf, off, len); + if (got > 0) { + return got; + } + + // Read another frame of data + try { + readFrame(); + } catch (SaslException e) { + throw new TTransportException(e); + } catch (TTransportException transportException) { + // If there is no-data or no-sasl header in the stream, log the failure, and rethrow. + if (transportException.getType() == TTransportException.END_OF_FILE) { + LOGGER.debug("No data or no sasl data in the stream during negotiation"); + } + throw transportException; + } + + return readBuffer.read(buf, off, len); + } + + /** + * Read a single frame of data from the underlying transport, unwrapping if + * necessary. + * + * @throws TTransportException + * Thrown if there's an error reading from the underlying transport. + * @throws SaslException + * Thrown if there's an error unwrapping the data. + */ + private void readFrame() throws TTransportException, SaslException { + int dataLength = readLength(); + + if (dataLength < 0) + throw new TTransportException("Read a negative frame size (" + dataLength + ")!"); + + byte[] buff = new byte[dataLength]; + LOGGER.debug("{}: reading data length: {}", getRole(), dataLength); + underlyingTransport.readAll(buff, 0, dataLength); + if (shouldWrap) { + buff = sasl.unwrap(buff, 0, buff.length); + LOGGER.debug("data length after unwrap: {}", buff.length); + } + readBuffer.reset(buff); + } + + /** + * Write to the underlying transport. + */ + @Override + public void write(byte[] buf, int off, int len) throws TTransportException { + if (!isOpen()) + throw new TTransportException("SASL authentication not complete"); + + writeBuffer.write(buf, off, len); + } + + /** + * Flushes to the underlying transport. Wraps the contents if a QOP was + * negotiated during the SASL handshake. + */ + @Override + public void flush() throws TTransportException { + byte[] buf = writeBuffer.get(); + int dataLength = writeBuffer.len(); + writeBuffer.reset(); + + if (shouldWrap) { + LOGGER.debug("data length before wrap: {}", dataLength); + try { + buf = sasl.wrap(buf, 0, dataLength); + } catch (SaslException e) { + throw new TTransportException(e); + } + dataLength = buf.length; + } + LOGGER.debug("writing data length: {}", dataLength); + writeLength(dataLength); + underlyingTransport.write(buf, 0, dataLength); + underlyingTransport.flush(); + } + + /** + * Used exclusively by readSaslMessage to return both a status and data. + */ + protected static class SaslResponse { + public NegotiationStatus status; + public byte[] payload; + + public SaslResponse(NegotiationStatus status, byte[] payload) { + this.status = status; + this.payload = payload; + } + } + + /** + * Used to abstract over the <code>SaslServer</code> and + * <code>SaslClient</code> classes, which share a lot of their interface, but + * unfortunately don't share a common superclass. + */ + private static class SaslParticipant { + // One of these will always be null. + public SaslServer saslServer; + public SaslClient saslClient; + + public SaslParticipant(SaslServer saslServer) { + this.saslServer = saslServer; + } + + public SaslParticipant(SaslClient saslClient) { + this.saslClient = saslClient; + } + + public byte[] evaluateChallengeOrResponse(byte[] challengeOrResponse) throws SaslException { + if (saslClient != null) { + return saslClient.evaluateChallenge(challengeOrResponse); + } else { + return saslServer.evaluateResponse(challengeOrResponse); + } + } + + public boolean isComplete() { + if (saslClient != null) + return saslClient.isComplete(); + else + return saslServer.isComplete(); + } + + public void dispose() throws SaslException { + if (saslClient != null) + saslClient.dispose(); + else + saslServer.dispose(); + } + + public byte[] unwrap(byte[] buf, int off, int len) throws SaslException { + if (saslClient != null) + return saslClient.unwrap(buf, off, len); + else + return saslServer.unwrap(buf, off, len); + } + + public byte[] wrap(byte[] buf, int off, int len) throws SaslException { + if (saslClient != null) + return saslClient.wrap(buf, off, len); + else + return saslServer.wrap(buf, off, len); + } + + public Object getNegotiatedProperty(String propName) { + if (saslClient != null) + return saslClient.getNegotiatedProperty(propName); + else + return saslServer.getNegotiatedProperty(propName); + } + } +} diff --git a/lib/java/src/main/java/org/apache/thrift/transport/TSeekableFile.java b/lib/java/src/main/java/org/apache/thrift/transport/TSeekableFile.java new file mode 100644 index 000000000..e02d36f6c --- /dev/null +++ b/lib/java/src/main/java/org/apache/thrift/transport/TSeekableFile.java @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.transport; + +import java.io.InputStream; +import java.io.OutputStream; +import java.io.IOException; + +public interface TSeekableFile { + + public InputStream getInputStream() throws IOException; + public OutputStream getOutputStream() throws IOException; + public void close() throws IOException; + public long length() throws IOException; + public void seek(long pos) throws IOException; +} diff --git a/lib/java/src/main/java/org/apache/thrift/transport/TServerSocket.java b/lib/java/src/main/java/org/apache/thrift/transport/TServerSocket.java new file mode 100644 index 000000000..eb302fd26 --- /dev/null +++ b/lib/java/src/main/java/org/apache/thrift/transport/TServerSocket.java @@ -0,0 +1,163 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.transport; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.net.InetSocketAddress; +import java.net.ServerSocket; +import java.net.Socket; +import java.net.SocketException; + +/** + * Wrapper around ServerSocket for Thrift. + * + */ +public class TServerSocket extends TServerTransport { + + private static final Logger LOGGER = LoggerFactory.getLogger(TServerSocket.class.getName()); + + /** + * Underlying ServerSocket object + */ + private ServerSocket serverSocket_ = null; + + /** + * Timeout for client sockets from accept + */ + private int clientTimeout_ = 0; + + public static class ServerSocketTransportArgs extends AbstractServerTransportArgs<ServerSocketTransportArgs> { + ServerSocket serverSocket; + + public ServerSocketTransportArgs serverSocket(ServerSocket serverSocket) { + this.serverSocket = serverSocket; + return this; + } + } + + /** + * Creates a server socket from underlying socket object + */ + public TServerSocket(ServerSocket serverSocket) throws TTransportException { + this(serverSocket, 0); + } + + /** + * Creates a server socket from underlying socket object + */ + public TServerSocket(ServerSocket serverSocket, int clientTimeout) throws TTransportException { + this(new ServerSocketTransportArgs().serverSocket(serverSocket).clientTimeout(clientTimeout)); + } + + /** + * Creates just a port listening server socket + */ + public TServerSocket(int port) throws TTransportException { + this(port, 0); + } + + /** + * Creates just a port listening server socket + */ + public TServerSocket(int port, int clientTimeout) throws TTransportException { + this(new InetSocketAddress(port), clientTimeout); + } + + public TServerSocket(InetSocketAddress bindAddr) throws TTransportException { + this(bindAddr, 0); + } + + public TServerSocket(InetSocketAddress bindAddr, int clientTimeout) throws TTransportException { + this(new ServerSocketTransportArgs().bindAddr(bindAddr).clientTimeout(clientTimeout)); + } + + public TServerSocket(ServerSocketTransportArgs args) throws TTransportException { + clientTimeout_ = args.clientTimeout; + if (args.serverSocket != null) { + this.serverSocket_ = args.serverSocket; + return; + } + try { + // Make server socket + serverSocket_ = new ServerSocket(); + // Prevent 2MSL delay problem on server restarts + serverSocket_.setReuseAddress(true); + // Bind to listening port + serverSocket_.bind(args.bindAddr, args.backlog); + } catch (IOException ioe) { + close(); + throw new TTransportException("Could not create ServerSocket on address " + args.bindAddr.toString() + ".", ioe); + } + } + + public void listen() throws TTransportException { + // Make sure to block on accept + if (serverSocket_ != null) { + try { + serverSocket_.setSoTimeout(0); + } catch (SocketException sx) { + LOGGER.error("Could not set socket timeout.", sx); + } + } + } + + @Override + public TSocket accept() throws TTransportException { + if (serverSocket_ == null) { + throw new TTransportException(TTransportException.NOT_OPEN, "No underlying server socket."); + } + Socket result; + try { + result = serverSocket_.accept(); + } catch (Exception e) { + throw new TTransportException(e); + } + if (result == null) { + throw new TTransportException("Blocking server's accept() may not return NULL"); + } + TSocket socket = new TSocket(result); + socket.setTimeout(clientTimeout_); + return socket; + } + + public void close() { + if (serverSocket_ != null) { + try { + serverSocket_.close(); + } catch (IOException iox) { + LOGGER.warn("Could not close server socket.", iox); + } + serverSocket_ = null; + } + } + + public void interrupt() { + // The thread-safeness of this is dubious, but Java documentation suggests + // that it is safe to do this from a different thread context + close(); + } + + public ServerSocket getServerSocket() { + return serverSocket_; + } +} diff --git a/lib/java/src/main/java/org/apache/thrift/transport/TServerTransport.java b/lib/java/src/main/java/org/apache/thrift/transport/TServerTransport.java new file mode 100644 index 000000000..3a7b49a31 --- /dev/null +++ b/lib/java/src/main/java/org/apache/thrift/transport/TServerTransport.java @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.transport; + +import java.io.Closeable; +import java.net.InetSocketAddress; + +import org.apache.thrift.TConfiguration; + +/** + * Server transport. Object which provides client transports. + * + */ +public abstract class TServerTransport implements Closeable { + + public static abstract class AbstractServerTransportArgs<T extends AbstractServerTransportArgs<T>> { + int backlog = 0; // A value of 0 means the default value will be used (currently set at 50) + int clientTimeout = 0; + InetSocketAddress bindAddr; + int maxFrameSize = TConfiguration.DEFAULT_MAX_FRAME_SIZE; + + public T backlog(int backlog) { + this.backlog = backlog; + return (T) this; + } + + public T clientTimeout(int clientTimeout) { + this.clientTimeout = clientTimeout; + return (T) this; + } + + public T port(int port) { + this.bindAddr = new InetSocketAddress(port); + return (T) this; + } + + public T bindAddr(InetSocketAddress bindAddr) { + this.bindAddr = bindAddr; + return (T) this; + } + + public T maxFrameSize(int maxFrameSize) { + this.maxFrameSize = maxFrameSize; + return (T) this; + } + } + + public abstract void listen() throws TTransportException; + + /** + * Accept incoming connection on the server socket. When there is no incoming connection available: + * either it should block infinitely in a blocking implementation, either it should return null in + * a nonblocking implementation. + * + * @return new connection + * @throws TTransportException if IO error. + */ + public abstract TTransport accept() throws TTransportException; + + public abstract void close(); + + /** + * Optional method implementation. This signals to the server transport + * that it should break out of any accept() or listen() that it is currently + * blocked on. This method, if implemented, MUST be thread safe, as it may + * be called from a different thread context than the other TServerTransport + * methods. + */ + public void interrupt() {} + +} diff --git a/lib/java/src/main/java/org/apache/thrift/transport/TSimpleFileTransport.java b/lib/java/src/main/java/org/apache/thrift/transport/TSimpleFileTransport.java new file mode 100644 index 000000000..c1bbd4853 --- /dev/null +++ b/lib/java/src/main/java/org/apache/thrift/transport/TSimpleFileTransport.java @@ -0,0 +1,236 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.thrift.transport; + +import org.apache.thrift.TConfiguration; + +import java.io.IOException; +import java.io.RandomAccessFile; + + +/** + * Basic file support for the TTransport interface + */ +public final class TSimpleFileTransport extends TEndpointTransport { + + private RandomAccessFile file = null; + private boolean readable; + private boolean writable; + private String path_; + + + /** + * Create a transport backed by a simple file + * + * @param path the path to the file to open/create + * @param read true to support read operations + * @param write true to support write operations + * @param openFile true to open the file on construction + * @throws TTransportException if file open fails + */ + public TSimpleFileTransport(String path, boolean read, + boolean write, boolean openFile) + throws TTransportException { + this(new TConfiguration(), path, read, write, openFile); + } + + /** + * Create a transport backed by a simple file + * + * @param config + * @param path the path to the file to open/create + * @param read true to support read operations + * @param write true to support write operations + * @param openFile true to open the file on construction + * @throws TTransportException if file open fails + */ + public TSimpleFileTransport(TConfiguration config, String path, boolean read, + boolean write, boolean openFile) + throws TTransportException { + super(config); + if (path.length() <= 0) { + throw new TTransportException("No path specified"); + } + if (!read && !write) { + throw new TTransportException("Neither READ nor WRITE specified"); + } + readable = read; + writable = write; + path_ = path; + if (openFile) { + open(); + } + } + + /** + * Create a transport backed by a simple file + * Implicitly opens file to conform to C++ behavior. + * + * @param path the path to the file to open/create + * @param read true to support read operations + * @param write true to support write operations + * @throws TTransportException if file open fails + */ + public TSimpleFileTransport(String path, boolean read, boolean write) + throws TTransportException { + this(path, read, write, true); + } + + /** + * Create a transport backed by a simple read only disk file (implicitly opens + * file) + * + * @param path the path to the file to open/create + * @throws TTransportException if file open fails + */ + public TSimpleFileTransport(String path) throws TTransportException { + this(path, true, false, true); + } + + /** + * Test file status + * + * @return true if open, otherwise false + */ + @Override + public boolean isOpen() { + return (file != null); + } + + /** + * Open file if not previously opened. + * + * @throws TTransportException if open fails + */ + @Override + public void open() throws TTransportException { + if (file == null){ + try { + String access = "r"; //RandomAccessFile objects must be readable + if (writable) { + access += "w"; + } + file = new RandomAccessFile(path_, access); + } catch (IOException ioe) { + file = null; + throw new TTransportException(ioe.getMessage()); + } + } + } + + /** + * Close file, subsequent read/write activity will throw exceptions + */ + @Override + public void close() { + if (file != null) { + try { + file.close(); + } catch (Exception e) { + //Nothing to do + } + file = null; + } + } + + /** + * Read up to len many bytes into buf at offset + * + * @param buf houses bytes read + * @param off offset into buff to begin writing to + * @param len maximum number of bytes to read + * @return number of bytes actually read + * @throws TTransportException on read failure + */ + @Override + public int read(byte[] buf, int off, int len) throws TTransportException { + if (!readable) { + throw new TTransportException("Read operation on write only file"); + } + checkReadBytesAvailable(len); + int iBytesRead = 0; + try { + iBytesRead = file.read(buf, off, len); + } catch (IOException ioe) { + file = null; + throw new TTransportException(ioe.getMessage()); + } + return iBytesRead; + } + + /** + * Write len many bytes from buff starting at offset + * + * @param buf buffer containing bytes to write + * @param off offset into buffer to begin writing from + * @param len number of bytes to write + * @throws TTransportException on write failure + */ + @Override + public void write(byte[] buf, int off, int len) throws TTransportException { + try { + file.write(buf, off, len); + } catch (IOException ioe) { + file = null; + throw new TTransportException(ioe.getMessage()); + } + } + + /** + * Move file pointer to specified offset, new read/write calls will act here + * + * @param offset bytes from beginning of file to move pointer to + * @throws TTransportException is seek fails + */ + public void seek(long offset) throws TTransportException { + try { + file.seek(offset); + } catch (IOException ex) { + throw new TTransportException(ex.getMessage()); + } + } + + /** + * Return the length of the file in bytes + * + * @return length of the file in bytes + * @throws TTransportException if file access fails + */ + public long length() throws TTransportException { + try { + return file.length(); + } catch (IOException ex) { + throw new TTransportException(ex.getMessage()); + } + } + + /** + * Return current file pointer position in bytes from beginning of file + * + * @return file pointer position + * @throws TTransportException if file access fails + */ + public long getFilePointer() throws TTransportException { + try { + return file.getFilePointer(); + } catch (IOException ex) { + throw new TTransportException(ex.getMessage()); + } + } +} diff --git a/lib/java/src/main/java/org/apache/thrift/transport/TSocket.java b/lib/java/src/main/java/org/apache/thrift/transport/TSocket.java new file mode 100644 index 000000000..aef2a3ff8 --- /dev/null +++ b/lib/java/src/main/java/org/apache/thrift/transport/TSocket.java @@ -0,0 +1,277 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.transport; + +import org.apache.thrift.TConfiguration; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.BufferedInputStream; +import java.io.BufferedOutputStream; +import java.io.IOException; +import java.net.InetSocketAddress; +import java.net.Socket; +import java.net.SocketException; + +/** + * Socket implementation of the TTransport interface. To be commented soon! + * + */ +public class TSocket extends TIOStreamTransport { + + private static final Logger LOGGER = LoggerFactory.getLogger(TSocket.class.getName()); + + /** + * Wrapped Socket object + */ + private Socket socket_; + + /** + * Remote host + */ + private String host_; + + /** + * Remote port + */ + private int port_; + + /** + * Socket timeout - read timeout on the socket + */ + private int socketTimeout_; + + /** + * Connection timeout + */ + private int connectTimeout_; + + /** + * Constructor that takes an already created socket. + * + * @param socket Already created socket object + * @throws TTransportException if there is an error setting up the streams + */ + public TSocket(Socket socket) throws TTransportException { + super(new TConfiguration()); + socket_ = socket; + try { + socket_.setSoLinger(false, 0); + socket_.setTcpNoDelay(true); + socket_.setKeepAlive(true); + } catch (SocketException sx) { + LOGGER.warn("Could not configure socket.", sx); + } + + if (isOpen()) { + try { + inputStream_ = new BufferedInputStream(socket_.getInputStream()); + outputStream_ = new BufferedOutputStream(socket_.getOutputStream()); + } catch (IOException iox) { + close(); + throw new TTransportException(TTransportException.NOT_OPEN, iox); + } + } + } + + /** + * Creates a new unconnected socket that will connect to the given host + * on the given port. + * + * @param config check config + * @param host Remote host + * @param port Remote port + */ + public TSocket(TConfiguration config, String host, int port) throws TTransportException { + this(config, host, port, 0); + } + + /** + * Creates a new unconnected socket that will connect to the given host + * on the given port. + * + * @param host Remote host + * @param port Remote port + */ + public TSocket(String host, int port) throws TTransportException { + this(new TConfiguration(), host, port, 0); + } + + /** + * Creates a new unconnected socket that will connect to the given host + * on the given port. + * + * @param host Remote host + * @param port Remote port + * @param timeout Socket timeout and connection timeout + */ + public TSocket(String host, int port, int timeout) throws TTransportException { + this(new TConfiguration(), host, port, timeout, timeout); + } + + /** + * Creates a new unconnected socket that will connect to the given host + * on the given port. + * + * @param config check config + * @param host Remote host + * @param port Remote port + * @param timeout Socket timeout and connection timeout + */ + public TSocket(TConfiguration config, String host, int port, int timeout) throws TTransportException { + this(config, host, port, timeout, timeout); + } + + /** + * Creates a new unconnected socket that will connect to the given host + * on the given port, with a specific connection timeout and a + * specific socket timeout. + * + * @param config check config + * @param host Remote host + * @param port Remote port + * @param socketTimeout Socket timeout + * @param connectTimeout Connection timeout + */ + public TSocket(TConfiguration config, String host, int port, int socketTimeout, int connectTimeout) throws TTransportException { + super(config); + host_ = host; + port_ = port; + socketTimeout_ = socketTimeout; + connectTimeout_ = connectTimeout; + initSocket(); + } + + /** + * Initializes the socket object + */ + private void initSocket() { + socket_ = new Socket(); + try { + socket_.setSoLinger(false, 0); + socket_.setTcpNoDelay(true); + socket_.setKeepAlive(true); + socket_.setSoTimeout(socketTimeout_); + } catch (SocketException sx) { + LOGGER.error("Could not configure socket.", sx); + } + } + + /** + * Sets the socket timeout and connection timeout. + * + * @param timeout Milliseconds timeout + */ + public void setTimeout(int timeout) { + this.setConnectTimeout(timeout); + this.setSocketTimeout(timeout); + } + + /** + * Sets the time after which the connection attempt will time out + * + * @param timeout Milliseconds timeout + */ + public void setConnectTimeout(int timeout) { + connectTimeout_ = timeout; + } + + /** + * Sets the socket timeout + * + * @param timeout Milliseconds timeout + */ + public void setSocketTimeout(int timeout) { + socketTimeout_ = timeout; + try { + socket_.setSoTimeout(timeout); + } catch (SocketException sx) { + LOGGER.warn("Could not set socket timeout.", sx); + } + } + + /** + * Returns a reference to the underlying socket. + */ + public Socket getSocket() { + if (socket_ == null) { + initSocket(); + } + return socket_; + } + + /** + * Checks whether the socket is connected. + */ + public boolean isOpen() { + if (socket_ == null) { + return false; + } + return socket_.isConnected(); + } + + /** + * Connects the socket, creating a new socket object if necessary. + */ + public void open() throws TTransportException { + if (isOpen()) { + throw new TTransportException(TTransportException.ALREADY_OPEN, "Socket already connected."); + } + + if (host_ == null || host_.length() == 0) { + throw new TTransportException(TTransportException.NOT_OPEN, "Cannot open null host."); + } + if (port_ <= 0 || port_ > 65535) { + throw new TTransportException(TTransportException.NOT_OPEN, "Invalid port " + port_); + } + + if (socket_ == null) { + initSocket(); + } + + try { + socket_.connect(new InetSocketAddress(host_, port_), connectTimeout_); + inputStream_ = new BufferedInputStream(socket_.getInputStream()); + outputStream_ = new BufferedOutputStream(socket_.getOutputStream()); + } catch (IOException iox) { + close(); + throw new TTransportException(TTransportException.NOT_OPEN, iox); + } + } + + /** + * Closes the socket. + */ + public void close() { + // Close the underlying streams + super.close(); + + // Close the socket + if (socket_ != null) { + try { + socket_.close(); + } catch (IOException iox) { + LOGGER.warn("Could not close socket.", iox); + } + socket_ = null; + } + } + +} diff --git a/lib/java/src/main/java/org/apache/thrift/transport/TStandardFile.java b/lib/java/src/main/java/org/apache/thrift/transport/TStandardFile.java new file mode 100644 index 000000000..7a33af8ee --- /dev/null +++ b/lib/java/src/main/java/org/apache/thrift/transport/TStandardFile.java @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.transport; + +import java.io.InputStream; +import java.io.OutputStream; +import java.io.IOException; +import java.io.RandomAccessFile; +import java.io.FileInputStream; +import java.io.FileOutputStream; + +public class TStandardFile implements TSeekableFile { + + protected String path_ = null; + protected RandomAccessFile inputFile_ = null; + + public TStandardFile(String path) throws IOException { + path_ = path; + inputFile_ = new RandomAccessFile(path_, "r"); + } + + public InputStream getInputStream() throws IOException { + return new FileInputStream(inputFile_.getFD()); + } + + public OutputStream getOutputStream() throws IOException { + return new FileOutputStream(path_); + } + + public void close() throws IOException { + if(inputFile_ != null) { + inputFile_.close(); + } + } + + public long length() throws IOException { + return inputFile_.length(); + } + + public void seek(long pos) throws IOException { + inputFile_.seek(pos); + } +} diff --git a/lib/java/src/main/java/org/apache/thrift/transport/TTransport.java b/lib/java/src/main/java/org/apache/thrift/transport/TTransport.java new file mode 100644 index 000000000..afe9cfb88 --- /dev/null +++ b/lib/java/src/main/java/org/apache/thrift/transport/TTransport.java @@ -0,0 +1,210 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.transport; + +import org.apache.thrift.TConfiguration; + +import java.io.Closeable; +import java.nio.ByteBuffer; + +/** + * Generic class that encapsulates the I/O layer. This is basically a thin + * wrapper around the combined functionality of Java input/output streams. + * + */ +public abstract class TTransport implements Closeable { + + /** + * Queries whether the transport is open. + * + * @return True if the transport is open. + */ + public abstract boolean isOpen(); + + /** + * Is there more data to be read? + * + * @return True if the remote side is still alive and feeding us + */ + public boolean peek() { + return isOpen(); + } + + /** + * Opens the transport for reading/writing. + * + * @throws TTransportException if the transport could not be opened + */ + public abstract void open() + throws TTransportException; + + /** + * Closes the transport. + */ + public abstract void close(); + + /** + * Reads a sequence of bytes from this channel into the given buffer. An + * attempt is made to read up to the number of bytes remaining in the buffer, + * that is, dst.remaining(), at the moment this method is invoked. Upon return + * the buffer's position will move forward the number of bytes read; its limit + * will not have changed. Subclasses are encouraged to provide a more + * efficient implementation of this method. + * + * @param dst The buffer into which bytes are to be transferred + * @return The number of bytes read, possibly zero, or -1 if the channel has + * reached end-of-stream + * @throws TTransportException if there was an error reading data + */ + public int read(ByteBuffer dst) throws TTransportException { + byte[] arr = new byte[dst.remaining()]; + int n = read(arr, 0, arr.length); + dst.put(arr, 0, n); + return n; + } + + /** + * Reads up to len bytes into buffer buf, starting at offset off. + * + * @param buf Array to read into + * @param off Index to start reading at + * @param len Maximum number of bytes to read + * @return The number of bytes actually read + * @throws TTransportException if there was an error reading data + */ + public abstract int read(byte[] buf, int off, int len) + throws TTransportException; + + /** + * Guarantees that all of len bytes are actually read off the transport. + * + * @param buf Array to read into + * @param off Index to start reading at + * @param len Maximum number of bytes to read + * @return The number of bytes actually read, which must be equal to len + * @throws TTransportException if there was an error reading data + */ + public int readAll(byte[] buf, int off, int len) + throws TTransportException { + int got = 0; + int ret = 0; + while (got < len) { + ret = read(buf, off+got, len-got); + if (ret <= 0) { + throw new TTransportException( + "Cannot read. Remote side has closed. Tried to read " + + len + + " bytes, but only got " + + got + + " bytes. (This is often indicative of an internal error on the server side. Please check your server logs.)"); + } + got += ret; + } + return got; + } + + /** + * Writes the buffer to the output + * + * @param buf The output data buffer + * @throws TTransportException if an error occurs writing data + */ + public void write(byte[] buf) throws TTransportException { + write(buf, 0, buf.length); + } + + /** + * Writes up to len bytes from the buffer. + * + * @param buf The output data buffer + * @param off The offset to start writing from + * @param len The number of bytes to write + * @throws TTransportException if there was an error writing data + */ + public abstract void write(byte[] buf, int off, int len) + throws TTransportException; + + /** + * Writes a sequence of bytes to the buffer. An attempt is made to write all + * remaining bytes in the buffer, that is, src.remaining(), at the moment this + * method is invoked. Upon return the buffer's position will updated; its limit + * will not have changed. Subclasses are encouraged to provide a more efficient + * implementation of this method. + * + * @param src The buffer from which bytes are to be retrieved + * @return The number of bytes written, possibly zero + * @throws TTransportException if there was an error writing data + */ + public int write(ByteBuffer src) throws TTransportException { + byte[] arr = new byte[src.remaining()]; + src.get(arr); + write(arr, 0, arr.length); + return arr.length; + } + + /** + * Flush any pending data out of a transport buffer. + * + * @throws TTransportException if there was an error writing out data. + */ + public void flush() + throws TTransportException {} + + /** + * Access the protocol's underlying buffer directly. If this is not a + * buffered transport, return null. + * @return protocol's Underlying buffer + */ + public byte[] getBuffer() { + return null; + } + + /** + * Return the index within the underlying buffer that specifies the next spot + * that should be read from. + * @return index within the underlying buffer that specifies the next spot + * that should be read from + */ + public int getBufferPosition() { + return 0; + } + + /** + * Get the number of bytes remaining in the underlying buffer. Returns -1 if + * this is a non-buffered transport. + * @return the number of bytes remaining in the underlying buffer. <br> Returns -1 if + * this is a non-buffered transport. + */ + public int getBytesRemainingInBuffer() { + return -1; + } + + /** + * Consume len bytes from the underlying buffer. + * @param len the number of bytes to consume from the underlying buffer. + */ + public void consumeBuffer(int len) {} + + public abstract TConfiguration getConfiguration(); + + public abstract void updateKnownMessageSize(long size) throws TTransportException; + + public abstract void checkReadBytesAvailable(long numBytes) throws TTransportException; +} diff --git a/lib/java/src/main/java/org/apache/thrift/transport/TTransportException.java b/lib/java/src/main/java/org/apache/thrift/transport/TTransportException.java new file mode 100644 index 000000000..b886bc269 --- /dev/null +++ b/lib/java/src/main/java/org/apache/thrift/transport/TTransportException.java @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.transport; + +import org.apache.thrift.TException; + +/** + * Transport exceptions. + * + */ +public class TTransportException extends TException { + + private static final long serialVersionUID = 1L; + + public static final int UNKNOWN = 0; + public static final int NOT_OPEN = 1; + public static final int ALREADY_OPEN = 2; + public static final int TIMED_OUT = 3; + public static final int END_OF_FILE = 4; + public static final int CORRUPTED_DATA = 5; + + protected int type_ = UNKNOWN; + + public TTransportException() { + super(); + } + + public TTransportException(int type) { + super(); + type_ = type; + } + + public TTransportException(int type, String message) { + super(message); + type_ = type; + } + + public TTransportException(String message) { + super(message); + } + + public TTransportException(int type, Throwable cause) { + super(cause); + type_ = type; + } + + public TTransportException(Throwable cause) { + super(cause); + } + + public TTransportException(String message, Throwable cause) { + super(message, cause); + } + + public TTransportException(int type, String message, Throwable cause) { + super(message, cause); + type_ = type; + } + + public int getType() { + return type_; + } + +} diff --git a/lib/java/src/main/java/org/apache/thrift/transport/TTransportFactory.java b/lib/java/src/main/java/org/apache/thrift/transport/TTransportFactory.java new file mode 100644 index 000000000..e068b4beb --- /dev/null +++ b/lib/java/src/main/java/org/apache/thrift/transport/TTransportFactory.java @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.transport; + +/** + * Factory class used to create wrapped instance of Transports. + * This is used primarily in servers, which get Transports from + * a ServerTransport and then may want to mutate them (i.e. create + * a BufferedTransport from the underlying base transport) + * + */ +public class TTransportFactory { + + /** + * Return a wrapped instance of the base Transport. + * + * @param trans The base transport + * @return Wrapped Transport + */ + public TTransport getTransport(TTransport trans) throws TTransportException { + return trans; + } + +} diff --git a/lib/java/src/main/java/org/apache/thrift/transport/TZlibTransport.java b/lib/java/src/main/java/org/apache/thrift/transport/TZlibTransport.java new file mode 100644 index 000000000..73b21aa3f --- /dev/null +++ b/lib/java/src/main/java/org/apache/thrift/transport/TZlibTransport.java @@ -0,0 +1,152 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.thrift.transport; + +import org.apache.thrift.TConfiguration; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.util.Objects; +import java.util.zip.Deflater; +import java.util.zip.DeflaterOutputStream; +import java.util.zip.Inflater; +import java.util.zip.InflaterInputStream; + +/** + * TZlibTransport deflates on write and inflates on read. + */ +public class TZlibTransport extends TIOStreamTransport { + + private TTransport transport_ = null; + + public static class Factory extends TTransportFactory { + public Factory() { + } + + @Override + public TTransport getTransport(TTransport base) throws TTransportException { + return new TZlibTransport(base); + } + } + + /** + * Constructs a new TZlibTransport instance. + * @param transport the underlying transport to read from and write to + */ + public TZlibTransport(TTransport transport) throws TTransportException { + this(transport, Deflater.BEST_COMPRESSION); + } + + /** + * Constructs a new TZlibTransport instance. + * @param transport the underlying transport to read from and write to + * @param compressionLevel 0 for no compression, 9 for maximum compression + */ + public TZlibTransport(TTransport transport, int compressionLevel) throws TTransportException { + super(Objects.isNull(transport.getConfiguration()) ? new TConfiguration() : transport.getConfiguration()); + transport_ = transport; + inputStream_ = new InflaterInputStream(new TTransportInputStream(transport_), new Inflater()); + outputStream_ = new DeflaterOutputStream(new TTransportOutputStream(transport_), new Deflater(compressionLevel, false), true); + } + + @Override + public boolean isOpen() { + return transport_.isOpen(); + } + + @Override + public void open() throws TTransportException { + transport_.open(); + } + + @Override + public void close() { + super.close(); + if (transport_.isOpen()) { + transport_.close(); + } + } +} + +class TTransportInputStream extends InputStream { + + private TTransport transport = null; + + public TTransportInputStream(TTransport transport) { + this.transport = transport; + } + + @Override + public int read() throws IOException { + try { + byte[] buf = new byte[1]; + transport.read(buf, 0, 1); + return buf[0]; + } catch (TTransportException e) { + throw new IOException(e); + } + } + + @Override + public int read(byte b[], int off, int len) throws IOException { + try { + return transport.read(b, off, len); + } catch (TTransportException e) { + throw new IOException(e); + } + } +} + +class TTransportOutputStream extends OutputStream { + + private TTransport transport = null; + + public TTransportOutputStream(TTransport transport) { + this.transport = transport; + } + + @Override + public void write(final int b) throws IOException { + try { + transport.write(new byte[]{(byte) b}); + } catch (TTransportException e) { + throw new IOException(e); + } + } + + @Override + public void write(byte b[], int off, int len) throws IOException { + try { + transport.write(b, off, len); + } catch (TTransportException e) { + throw new IOException(e); + } + } + + @Override + public void flush() throws IOException { + try { + transport.flush(); + } catch (TTransportException e) { + throw new IOException(e); + } + } +} + diff --git a/lib/java/src/main/java/org/apache/thrift/transport/layered/TFastFramedTransport.java b/lib/java/src/main/java/org/apache/thrift/transport/layered/TFastFramedTransport.java new file mode 100644 index 000000000..29bf39c14 --- /dev/null +++ b/lib/java/src/main/java/org/apache/thrift/transport/layered/TFastFramedTransport.java @@ -0,0 +1,203 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.thrift.transport.layered; + + +import org.apache.thrift.TConfiguration; +import org.apache.thrift.transport.*; + +import java.util.Objects; + +/** + * This transport is wire compatible with {@link TFramedTransport}, but makes + * use of reusable, expanding read and write buffers in order to avoid + * allocating new byte[]s all the time. Since the buffers only expand, you + * should probably only use this transport if your messages are not too variably + * large, unless the persistent memory cost is not an issue. + * + * This implementation is NOT threadsafe. + */ +public class TFastFramedTransport extends TLayeredTransport { + + public static class Factory extends TTransportFactory { + private final int initialCapacity; + private final int maxLength; + + public Factory() { + this(DEFAULT_BUF_CAPACITY, TConfiguration.DEFAULT_MAX_FRAME_SIZE); + } + + public Factory(int initialCapacity) { + this(initialCapacity, TConfiguration.DEFAULT_MAX_FRAME_SIZE); + } + + public Factory(int initialCapacity, int maxLength) { + this.initialCapacity = initialCapacity; + this.maxLength = maxLength; + } + + @Override + public TTransport getTransport(TTransport trans) throws TTransportException { + return new TFastFramedTransport(trans, + initialCapacity, + maxLength); + } + } + + /** + * How big should the default read and write buffers be? + */ + public static final int DEFAULT_BUF_CAPACITY = 1024; + + private final AutoExpandingBufferWriteTransport writeBuffer; + private AutoExpandingBufferReadTransport readBuffer; + private final int initialBufferCapacity; + private final byte[] i32buf = new byte[4]; + private final int maxLength; + + /** + * Create a new {@link TFastFramedTransport}. Use the defaults + * for initial buffer size and max frame length. + * @param underlying Transport that real reads and writes will go through to. + */ + public TFastFramedTransport(TTransport underlying) throws TTransportException { + this(underlying, DEFAULT_BUF_CAPACITY, TConfiguration.DEFAULT_MAX_FRAME_SIZE); + } + + /** + * Create a new {@link TFastFramedTransport}. Use the specified + * initial buffer capacity and the default max frame length. + * @param underlying Transport that real reads and writes will go through to. + * @param initialBufferCapacity The initial size of the read and write buffers. + * In practice, it's not critical to set this unless you know in advance that + * your messages are going to be very large. + */ + public TFastFramedTransport(TTransport underlying, int initialBufferCapacity) throws TTransportException { + this(underlying, initialBufferCapacity, TConfiguration.DEFAULT_MAX_FRAME_SIZE); + } + + /** + * + * @param underlying Transport that real reads and writes will go through to. + * @param initialBufferCapacity The initial size of the read and write buffers. + * In practice, it's not critical to set this unless you know in advance that + * your messages are going to be very large. (You can pass + * TFramedTransportWithReusableBuffer.DEFAULT_BUF_CAPACITY if you're only + * using this constructor because you want to set the maxLength.) + * @param maxLength The max frame size you are willing to read. You can use + * this parameter to limit how much memory can be allocated. + */ + public TFastFramedTransport(TTransport underlying, int initialBufferCapacity, int maxLength) throws TTransportException { + super(underlying); + TConfiguration config = Objects.isNull(underlying.getConfiguration()) ? new TConfiguration() : underlying.getConfiguration(); + this.maxLength = maxLength; + config.setMaxFrameSize(maxLength); + this.initialBufferCapacity = initialBufferCapacity; + readBuffer = new AutoExpandingBufferReadTransport(config, initialBufferCapacity); + writeBuffer = new AutoExpandingBufferWriteTransport(config, initialBufferCapacity, 4); + } + + @Override + public void close() { + getInnerTransport().close(); + } + + @Override + public boolean isOpen() { + return getInnerTransport().isOpen(); + } + + @Override + public void open() throws TTransportException { + getInnerTransport().open(); + } + + @Override + public int read(byte[] buf, int off, int len) throws TTransportException { + int got = readBuffer.read(buf, off, len); + if (got > 0) { + return got; + } + + // Read another frame of data + readFrame(); + + return readBuffer.read(buf, off, len); + } + + private void readFrame() throws TTransportException { + getInnerTransport().readAll(i32buf , 0, 4); + int size = TFramedTransport.decodeFrameSize(i32buf); + + if (size < 0) { + close(); + throw new TTransportException(TTransportException.CORRUPTED_DATA, "Read a negative frame size (" + size + ")!"); + } + + if (size > getInnerTransport().getConfiguration().getMaxFrameSize()) { + close(); + throw new TTransportException(TTransportException.CORRUPTED_DATA, + "Frame size (" + size + ") larger than max length (" + maxLength + ")!"); + } + + readBuffer.fill(getInnerTransport(), size); + } + + @Override + public void write(byte[] buf, int off, int len) throws TTransportException { + writeBuffer.write(buf, off, len); + } + + @Override + public void consumeBuffer(int len) { + readBuffer.consumeBuffer(len); + } + + /** + * Only clears the read buffer! + */ + public void clear() throws TTransportException { + readBuffer = new AutoExpandingBufferReadTransport(getConfiguration(), initialBufferCapacity); + } + + @Override + public void flush() throws TTransportException { + int payloadLength = writeBuffer.getLength() - 4; + byte[] data = writeBuffer.getBuf().array(); + TFramedTransport.encodeFrameSize(payloadLength, data); + getInnerTransport().write(data, 0, payloadLength + 4); + writeBuffer.reset(); + getInnerTransport().flush(); + } + + @Override + public byte[] getBuffer() { + return readBuffer.getBuffer(); + } + + @Override + public int getBufferPosition() { + return readBuffer.getBufferPosition(); + } + + @Override + public int getBytesRemainingInBuffer() { + return readBuffer.getBytesRemainingInBuffer(); + } +} diff --git a/lib/java/src/main/java/org/apache/thrift/transport/layered/TFramedTransport.java b/lib/java/src/main/java/org/apache/thrift/transport/layered/TFramedTransport.java new file mode 100644 index 000000000..10a9a1c17 --- /dev/null +++ b/lib/java/src/main/java/org/apache/thrift/transport/layered/TFramedTransport.java @@ -0,0 +1,187 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.transport.layered; + +import org.apache.thrift.TByteArrayOutputStream; +import org.apache.thrift.TConfiguration; +import org.apache.thrift.transport.TMemoryInputTransport; +import org.apache.thrift.transport.TTransport; +import org.apache.thrift.transport.TTransportException; +import org.apache.thrift.transport.TTransportFactory; + +import java.util.Objects; + +/** + * TFramedTransport is a buffered TTransport that ensures a fully read message + * every time by preceding messages with a 4-byte frame size. + */ +public class TFramedTransport extends TLayeredTransport { + + /** + * Buffer for output + */ + private final TByteArrayOutputStream writeBuffer_ = + new TByteArrayOutputStream(1024); + + /** + * Buffer for input + */ + private final TMemoryInputTransport readBuffer_; + + public static class Factory extends TTransportFactory { + private int maxLength_; + + public Factory() { + maxLength_ = TConfiguration.DEFAULT_MAX_FRAME_SIZE; + } + + public Factory(int maxLength) { + maxLength_ = maxLength; + } + + @Override + public TTransport getTransport(TTransport base) throws TTransportException { + return new TFramedTransport(base, maxLength_); + } + } + + /** + * Something to fill in the first four bytes of the buffer + * to make room for the frame size. This allows the + * implementation to write once instead of twice. + */ + private static final byte[] sizeFiller_ = new byte[] { 0x00, 0x00, 0x00, 0x00 }; + + /** + * Constructor wraps around another transport + */ + public TFramedTransport(TTransport transport, int maxLength) throws TTransportException { + super(transport); + TConfiguration _configuration = Objects.isNull(transport.getConfiguration()) ? new TConfiguration() : transport.getConfiguration(); + _configuration.setMaxFrameSize(maxLength); + writeBuffer_.write(sizeFiller_, 0, 4); + readBuffer_= new TMemoryInputTransport(_configuration, new byte[0]); + } + + public TFramedTransport(TTransport transport) throws TTransportException { + this(transport, TConfiguration.DEFAULT_MAX_FRAME_SIZE); + } + + public void open() throws TTransportException { + getInnerTransport().open(); + } + + public boolean isOpen() { + return getInnerTransport().isOpen(); + } + + public void close() { + getInnerTransport().close(); + } + + public int read(byte[] buf, int off, int len) throws TTransportException { + int got = readBuffer_.read(buf, off, len); + if (got > 0) { + return got; + } + + // Read another frame of data + readFrame(); + + return readBuffer_.read(buf, off, len); + } + + @Override + public byte[] getBuffer() { + return readBuffer_.getBuffer(); + } + + @Override + public int getBufferPosition() { + return readBuffer_.getBufferPosition(); + } + + @Override + public int getBytesRemainingInBuffer() { + return readBuffer_.getBytesRemainingInBuffer(); + } + + @Override + public void consumeBuffer(int len) { + readBuffer_.consumeBuffer(len); + } + + public void clear() { + readBuffer_.clear(); + } + + private final byte[] i32buf = new byte[4]; + + private void readFrame() throws TTransportException { + getInnerTransport().readAll(i32buf, 0, 4); + int size = decodeFrameSize(i32buf); + + if (size < 0) { + close(); + throw new TTransportException(TTransportException.CORRUPTED_DATA, "Read a negative frame size (" + size + ")!"); + } + + if (size > getInnerTransport().getConfiguration().getMaxFrameSize()) { + close(); + throw new TTransportException(TTransportException.CORRUPTED_DATA, + "Frame size (" + size + ") larger than max length (" + getInnerTransport().getConfiguration().getMaxFrameSize() + ")!"); + } + + byte[] buff = new byte[size]; + getInnerTransport().readAll(buff, 0, size); + readBuffer_.reset(buff); + } + + public void write(byte[] buf, int off, int len) throws TTransportException { + writeBuffer_.write(buf, off, len); + } + + @Override + public void flush() throws TTransportException { + byte[] buf = writeBuffer_.get(); + int len = writeBuffer_.len() - 4; // account for the prepended frame size + writeBuffer_.reset(); + writeBuffer_.write(sizeFiller_, 0, 4); // make room for the next frame's size data + + encodeFrameSize(len, buf); // this is the frame length without the filler + getInnerTransport().write(buf, 0, len + 4); // we have to write the frame size and frame data + getInnerTransport().flush(); + } + + public static final void encodeFrameSize(final int frameSize, final byte[] buf) { + buf[0] = (byte)(0xff & (frameSize >> 24)); + buf[1] = (byte)(0xff & (frameSize >> 16)); + buf[2] = (byte)(0xff & (frameSize >> 8)); + buf[3] = (byte)(0xff & (frameSize)); + } + + public static final int decodeFrameSize(final byte[] buf) { + return + ((buf[0] & 0xff) << 24) | + ((buf[1] & 0xff) << 16) | + ((buf[2] & 0xff) << 8) | + ((buf[3] & 0xff)); + } +} diff --git a/lib/java/src/main/java/org/apache/thrift/transport/layered/TLayeredTransport.java b/lib/java/src/main/java/org/apache/thrift/transport/layered/TLayeredTransport.java new file mode 100644 index 000000000..69ec824ee --- /dev/null +++ b/lib/java/src/main/java/org/apache/thrift/transport/layered/TLayeredTransport.java @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.thrift.transport.layered; + +import org.apache.thrift.TConfiguration; +import org.apache.thrift.transport.TTransport; +import org.apache.thrift.transport.TTransportException; + +import java.util.Objects; + +public abstract class TLayeredTransport extends TTransport{ + + private TTransport innerTransport; + + public TConfiguration getConfiguration() { + return innerTransport.getConfiguration(); + } + + public TLayeredTransport(TTransport transport) + { + Objects.requireNonNull(transport, "TTransport cannot be null."); + innerTransport = transport; + } + + public void updateKnownMessageSize(long size) throws TTransportException { + innerTransport.updateKnownMessageSize(size); + } + + public void checkReadBytesAvailable(long numBytes) throws TTransportException { + innerTransport.checkReadBytesAvailable(numBytes); + } + + public TTransport getInnerTransport() { + return innerTransport; + } +} diff --git a/lib/java/src/main/java/org/apache/thrift/transport/sasl/DataFrameHeaderReader.java b/lib/java/src/main/java/org/apache/thrift/transport/sasl/DataFrameHeaderReader.java new file mode 100644 index 000000000..2900df9c1 --- /dev/null +++ b/lib/java/src/main/java/org/apache/thrift/transport/sasl/DataFrameHeaderReader.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.transport.sasl; + +/** + * The header for data frame, it only contains a 4-byte payload size. + */ +public class DataFrameHeaderReader extends FixedSizeHeaderReader { + public static final int PAYLOAD_LENGTH_BYTES = 4; + + private int payloadSize; + + @Override + protected int headerSize() { + return PAYLOAD_LENGTH_BYTES; + } + + @Override + protected void onComplete() throws TInvalidSaslFrameException { + payloadSize = byteBuffer.getInt(0); + if (payloadSize < 0) { + throw new TInvalidSaslFrameException("Payload size is negative: " + payloadSize); + } + } + + @Override + public int payloadSize() { + return payloadSize; + } +} diff --git a/lib/java/src/main/java/org/apache/thrift/transport/sasl/DataFrameReader.java b/lib/java/src/main/java/org/apache/thrift/transport/sasl/DataFrameReader.java new file mode 100644 index 000000000..e6900bbc6 --- /dev/null +++ b/lib/java/src/main/java/org/apache/thrift/transport/sasl/DataFrameReader.java @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.transport.sasl; + +/** + * Frames for thrift (serialized) messages. + */ +public class DataFrameReader extends FrameReader<DataFrameHeaderReader> { + + public DataFrameReader() { + super(new DataFrameHeaderReader()); + } +} diff --git a/lib/java/src/main/java/org/apache/thrift/transport/sasl/DataFrameWriter.java b/lib/java/src/main/java/org/apache/thrift/transport/sasl/DataFrameWriter.java new file mode 100644 index 000000000..a2dd15a8c --- /dev/null +++ b/lib/java/src/main/java/org/apache/thrift/transport/sasl/DataFrameWriter.java @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.transport.sasl; + +import java.nio.ByteBuffer; + +import org.apache.thrift.EncodingUtils; +import org.apache.thrift.utils.StringUtils; + +import static org.apache.thrift.transport.sasl.DataFrameHeaderReader.PAYLOAD_LENGTH_BYTES; + +/** + * Write frames of thrift messages. It expects an empty/null header to be provided with a payload + * to be written out. Non empty headers are considered as error. + */ +public class DataFrameWriter extends FrameWriter { + + @Override + public void withOnlyPayload(byte[] payload, int offset, int length) { + if (!isComplete()) { + throw new IllegalStateException("Previsous write is not yet complete, with " + + frameBytes.remaining() + " bytes left."); + } + frameBytes = buildFrameWithPayload(payload, offset, length); + } + + @Override + protected ByteBuffer buildFrame(byte[] header, int headerOffset, int headerLength, + byte[] payload, int payloadOffset, int payloadLength) { + if (header != null && headerLength > 0) { + throw new IllegalArgumentException("Extra header [" + StringUtils.bytesToHexString(header) + + "] offset " + payloadOffset + " length " + payloadLength); + } + return buildFrameWithPayload(payload, payloadOffset, payloadLength); + } + + private ByteBuffer buildFrameWithPayload(byte[] payload, int offset, int length) { + byte[] bytes = new byte[PAYLOAD_LENGTH_BYTES + length]; + EncodingUtils.encodeBigEndian(length, bytes, 0); + System.arraycopy(payload, offset, bytes, PAYLOAD_LENGTH_BYTES, length); + return ByteBuffer.wrap(bytes); + } +} diff --git a/lib/java/src/main/java/org/apache/thrift/transport/sasl/FixedSizeHeaderReader.java b/lib/java/src/main/java/org/apache/thrift/transport/sasl/FixedSizeHeaderReader.java new file mode 100644 index 000000000..1cbc0ace0 --- /dev/null +++ b/lib/java/src/main/java/org/apache/thrift/transport/sasl/FixedSizeHeaderReader.java @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.transport.sasl; + +import org.apache.thrift.transport.TTransport; +import org.apache.thrift.transport.TTransportException; +import org.apache.thrift.utils.StringUtils; + +import java.nio.ByteBuffer; + +/** + * Headers' size should be predefined. + */ +public abstract class FixedSizeHeaderReader implements FrameHeaderReader { + + protected final ByteBuffer byteBuffer = ByteBuffer.allocate(headerSize()); + + @Override + public boolean isComplete() { + return !byteBuffer.hasRemaining(); + } + + @Override + public void clear() { + byteBuffer.clear(); + } + + @Override + public byte[] toBytes() { + if (!isComplete()) { + throw new IllegalStateException("Header is not yet complete " + StringUtils.bytesToHexString(byteBuffer.array(), 0, byteBuffer.position())); + } + return byteBuffer.array(); + } + + @Override + public boolean read(TTransport transport) throws TTransportException { + FrameReader.readAvailable(transport, byteBuffer); + if (byteBuffer.hasRemaining()) { + return false; + } + onComplete(); + return true; + } + + /** + * @return Size of the header. + */ + protected abstract int headerSize(); + + /** + * Actions (e.g. validation) to carry out when the header is complete. + * + * @throws TTransportException + */ + protected abstract void onComplete() throws TTransportException; +} diff --git a/lib/java/src/main/java/org/apache/thrift/transport/sasl/FrameHeaderReader.java b/lib/java/src/main/java/org/apache/thrift/transport/sasl/FrameHeaderReader.java new file mode 100644 index 000000000..f7c659315 --- /dev/null +++ b/lib/java/src/main/java/org/apache/thrift/transport/sasl/FrameHeaderReader.java @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.transport.sasl; + +import org.apache.thrift.transport.TTransport; +import org.apache.thrift.transport.TTransportException; + +/** + * Read headers for a frame. For each frame, the header contains payload size and other metadata. + */ +public interface FrameHeaderReader { + + /** + * As the thrift sasl specification states, all sasl messages (both for negotiatiing and for + * sending data) should have a header to indicate the size of the payload. + * + * @return size of the payload. + */ + int payloadSize(); + + /** + * + * @return The received bytes for the header. + * @throws IllegalStateException if isComplete returns false. + */ + byte[] toBytes(); + + /** + * @return true if this header has all its fields set. + */ + boolean isComplete(); + + /** + * Clear the header and make it available to read a new header. + */ + void clear(); + + /** + * (Nonblocking) Read fields from underlying transport layer. + * + * @param transport underlying transport. + * @return true if header is complete after read. + * @throws TSaslNegotiationException if fail to read a valid header of a sasl negotiation message. + * @throws TTransportException if io error. + */ + boolean read(TTransport transport) throws TSaslNegotiationException, TTransportException; +} diff --git a/lib/java/src/main/java/org/apache/thrift/transport/sasl/FrameReader.java b/lib/java/src/main/java/org/apache/thrift/transport/sasl/FrameReader.java new file mode 100644 index 000000000..acb4b738d --- /dev/null +++ b/lib/java/src/main/java/org/apache/thrift/transport/sasl/FrameReader.java @@ -0,0 +1,162 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.transport.sasl; + +import org.apache.thrift.transport.TEOFException; +import org.apache.thrift.transport.TTransport; +import org.apache.thrift.transport.TTransportException; + +import java.nio.ByteBuffer; + +/** + * Read frames from a transport. Each frame has a header and a payload. A header will indicate + * the size of the payload and other informations about how to decode payload. + * Implementations should subclass it by providing a header reader implementation. + * + * @param <T> Header type. + */ +public abstract class FrameReader<T extends FrameHeaderReader> { + private final T header; + private ByteBuffer payload; + + protected FrameReader(T header) { + this.header = header; + } + + /** + * (Nonblocking) Read available bytes out of the transport without blocking to wait for incoming + * data. + * + * @param transport TTransport + * @return true if current frame is complete after read. + * @throws TSaslNegotiationException if fail to read back a valid sasl negotiation message. + * @throws TTransportException if io error. + */ + public boolean read(TTransport transport) throws TSaslNegotiationException, TTransportException { + if (!header.isComplete()) { + if (readHeader(transport)) { + payload = ByteBuffer.allocate(header.payloadSize()); + } else { + return false; + } + } + if (header.payloadSize() == 0) { + return true; + } + return readPayload(transport); + } + + /** + * (Nonblocking) Try to read available header bytes from transport. + * + * @return true if header is complete after read. + * @throws TSaslNegotiationException if fail to read back a validd sasl negotiation header. + * @throws TTransportException if io error. + */ + private boolean readHeader(TTransport transport) throws TSaslNegotiationException, TTransportException { + return header.read(transport); + } + + /** + * (Nonblocking) Try to read available + * + * @param transport underlying transport. + * @return true if payload is complete after read. + * @throws TTransportException if io error. + */ + private boolean readPayload(TTransport transport) throws TTransportException { + readAvailable(transport, payload); + return payload.hasRemaining(); + } + + /** + * + * @return header of the frame + */ + public T getHeader() { + return header; + } + + /** + * + * @return number of bytes of the header + */ + public int getHeaderSize() { + return header.toBytes().length; + } + + /** + * + * @return byte array of the payload + */ + public byte[] getPayload() { + return payload.array(); + } + + /** + * + * @return size of the payload + */ + public int getPayloadSize() { + return header.payloadSize(); + } + + /** + * + * @return true if the reader has fully read a frame + */ + public boolean isComplete() { + return !(payload == null || payload.hasRemaining()); + } + + /** + * Reset the state of the reader so that it can be reused to read a new frame. + */ + public void clear() { + header.clear(); + payload = null; + } + + /** + * Read immediately available bytes from the transport into the byte buffer. + * + * @param transport TTransport + * @param recipient ByteBuffer + * @return number of bytes read out of the transport + * @throws TTransportException if io error + */ + static int readAvailable(TTransport transport, ByteBuffer recipient) throws TTransportException { + if (!recipient.hasRemaining()) { + throw new IllegalStateException("Trying to fill a full recipient with " + recipient.limit() + + " bytes"); + } + int currentPosition = recipient.position(); + byte[] bytes = recipient.array(); + int offset = recipient.arrayOffset() + currentPosition; + int expectedLength = recipient.remaining(); + int got = transport.read(bytes, offset, expectedLength); + if (got < 0) { + throw new TEOFException("Transport is closed, while trying to read " + expectedLength + + " bytes"); + } + recipient.position(currentPosition + got); + return got; + } +} diff --git a/lib/java/src/main/java/org/apache/thrift/transport/sasl/FrameWriter.java b/lib/java/src/main/java/org/apache/thrift/transport/sasl/FrameWriter.java new file mode 100644 index 000000000..4357f13e1 --- /dev/null +++ b/lib/java/src/main/java/org/apache/thrift/transport/sasl/FrameWriter.java @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.transport.sasl; + +import java.io.IOException; +import java.nio.ByteBuffer; + +import org.apache.thrift.transport.TNonblockingTransport; +import org.apache.thrift.transport.TTransportException; + +/** + * Write frame (header and payload) to transport in a nonblocking way. + */ +public abstract class FrameWriter { + + protected ByteBuffer frameBytes; + + /** + * Provide (maybe empty) header and payload to the frame. This can be called only when isComplete + * returns true (last frame has been written out). + * + * @param header Some extra header bytes (without the 4 bytes for payload length), which will be + * the start of the frame. It can be empty, depending on the message format + * @param payload Payload as a byte array + * @throws IllegalStateException if it is called when isComplete returns false + * @throws IllegalArgumentException if header or payload is invalid + */ + public void withHeaderAndPayload(byte[] header, byte[] payload) { + if (payload == null) { + payload = new byte[0]; + } + if (header == null) { + withOnlyPayload(payload); + } else { + withHeaderAndPayload(header, 0, header.length, payload, 0, payload.length); + } + } + + /** + * Provide extra header and payload to the frame. + * + * @param header byte array containing the extra header + * @param headerOffset starting offset of the header portition + * @param headerLength length of the extra header + * @param payload byte array containing the payload + * @param payloadOffset starting offset of the payload portion + * @param payloadLength length of the payload + * @throws IllegalStateException if preivous frame is not yet complete (isComplete returns fals) + * @throws IllegalArgumentException if header or payload is invalid + */ + public void withHeaderAndPayload(byte[] header, int headerOffset, int headerLength, + byte[] payload, int payloadOffset, int payloadLength) { + if (!isComplete()) { + throw new IllegalStateException("Previsous write is not yet complete, with " + + frameBytes.remaining() + " bytes left."); + } + frameBytes = buildFrame(header, headerOffset, headerLength, payload, payloadOffset, payloadLength); + } + + /** + * Provide only payload to the frame. Throws UnsupportedOperationException if the frame expects + * a header. + * + * @param payload payload as a byte array + */ + public void withOnlyPayload(byte[] payload) { + withOnlyPayload(payload, 0, payload.length); + } + + /** + * Provide only payload to the frame. Throws UnsupportedOperationException if the frame expects + * a header. + * + * @param payload The underlying byte array as a recipient of the payload + * @param offset The offset in the byte array starting from where the payload is located + * @param length The length of the payload + */ + public abstract void withOnlyPayload(byte[] payload, int offset, int length); + + protected abstract ByteBuffer buildFrame(byte[] header, int headerOffset, int headerLength, + byte[] payload, int payloadOffset, int payloadLength); + + /** + * Nonblocking write to the underlying transport. + * + * @throws TTransportException + */ + public void write(TNonblockingTransport transport) throws TTransportException { + transport.write(frameBytes); + } + + /** + * + * @return true when no more data needs to be written out + */ + public boolean isComplete() { + return frameBytes == null || !frameBytes.hasRemaining(); + } + + /** + * Release the byte buffer. + */ + public void clear() { + frameBytes = null; + } +} diff --git a/lib/java/src/main/java/org/apache/thrift/transport/sasl/NegotiationStatus.java b/lib/java/src/main/java/org/apache/thrift/transport/sasl/NegotiationStatus.java new file mode 100644 index 000000000..ad704a0a1 --- /dev/null +++ b/lib/java/src/main/java/org/apache/thrift/transport/sasl/NegotiationStatus.java @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.transport.sasl; + +import java.util.HashMap; +import java.util.Map; + +import static org.apache.thrift.transport.sasl.TSaslNegotiationException.ErrorType.PROTOCOL_ERROR; + +/** + * Status bytes used during the initial Thrift SASL handshake. + */ +public enum NegotiationStatus { + START((byte)0x01), + OK((byte)0x02), + BAD((byte)0x03), + ERROR((byte)0x04), + COMPLETE((byte)0x05); + + private static final Map<Byte, NegotiationStatus> reverseMap = new HashMap<>(); + + static { + for (NegotiationStatus s : NegotiationStatus.values()) { + reverseMap.put(s.getValue(), s); + } + } + + private final byte value; + + NegotiationStatus(byte val) { + this.value = val; + } + + public byte getValue() { + return value; + } + + public static NegotiationStatus byValue(byte val) throws TSaslNegotiationException { + if (!reverseMap.containsKey(val)) { + throw new TSaslNegotiationException(PROTOCOL_ERROR, "Invalid status " + val); + } + return reverseMap.get(val); + } +} diff --git a/lib/java/src/main/java/org/apache/thrift/transport/sasl/NonblockingSaslHandler.java b/lib/java/src/main/java/org/apache/thrift/transport/sasl/NonblockingSaslHandler.java new file mode 100644 index 000000000..d73c3ec18 --- /dev/null +++ b/lib/java/src/main/java/org/apache/thrift/transport/sasl/NonblockingSaslHandler.java @@ -0,0 +1,527 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.transport.sasl; + +import java.nio.channels.SelectionKey; +import java.nio.charset.StandardCharsets; + +import javax.security.sasl.SaslServer; + +import org.apache.thrift.TByteArrayOutputStream; +import org.apache.thrift.TProcessor; +import org.apache.thrift.protocol.TProtocol; +import org.apache.thrift.protocol.TProtocolFactory; +import org.apache.thrift.server.ServerContext; +import org.apache.thrift.server.TServerEventHandler; +import org.apache.thrift.transport.TMemoryTransport; +import org.apache.thrift.transport.TNonblockingTransport; +import org.apache.thrift.transport.TTransportException; +import org.apache.thrift.transport.sasl.TSaslNegotiationException.ErrorType; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import static org.apache.thrift.transport.sasl.NegotiationStatus.COMPLETE; +import static org.apache.thrift.transport.sasl.NegotiationStatus.OK; + +/** + * State machine managing one sasl connection in a nonblocking way. + */ +public class NonblockingSaslHandler { + private static final Logger LOGGER = LoggerFactory.getLogger(NonblockingSaslHandler.class); + + private static final int INTEREST_NONE = 0; + private static final int INTEREST_READ = SelectionKey.OP_READ; + private static final int INTEREST_WRITE = SelectionKey.OP_WRITE; + + // Tracking the current running phase + private Phase currentPhase = Phase.INITIIALIIZING; + // Tracking the next phase on the next invocation of the state machine. + // It should be the same as current phase if current phase is not yet finished. + // Otherwise, if it is different from current phase, the statemachine is in a transition state: + // current phase is done, and next phase is not yet started. + private Phase nextPhase = currentPhase; + + // Underlying nonblocking transport + private SelectionKey selectionKey; + private TNonblockingTransport underlyingTransport; + + // APIs for intercepting event / customizing behaviors: + // Factories (decorating the base implementations) & EventHandler (intercepting) + private TSaslServerFactory saslServerFactory; + private TSaslProcessorFactory processorFactory; + private TProtocolFactory inputProtocolFactory; + private TProtocolFactory outputProtocolFactory; + private TServerEventHandler eventHandler; + private ServerContext serverContext; + // It turns out the event handler implementation in hive sometimes creates a null ServerContext. + // In order to know whether TServerEventHandler#createContext is called we use such a flag. + private boolean serverContextCreated = false; + + // Wrapper around sasl server + private ServerSaslPeer saslPeer; + + // Sasl negotiation io + private SaslNegotiationFrameReader saslResponse; + private SaslNegotiationFrameWriter saslChallenge; + // IO for request from and response to the socket + private DataFrameReader requestReader; + private DataFrameWriter responseWriter; + // If sasl is negotiated for integrity/confidentiality protection + private boolean dataProtected; + + public NonblockingSaslHandler(SelectionKey selectionKey, TNonblockingTransport underlyingTransport, + TSaslServerFactory saslServerFactory, TSaslProcessorFactory processorFactory, + TProtocolFactory inputProtocolFactory, TProtocolFactory outputProtocolFactory, + TServerEventHandler eventHandler) { + this.selectionKey = selectionKey; + this.underlyingTransport = underlyingTransport; + this.saslServerFactory = saslServerFactory; + this.processorFactory = processorFactory; + this.inputProtocolFactory = inputProtocolFactory; + this.outputProtocolFactory = outputProtocolFactory; + this.eventHandler = eventHandler; + + saslResponse = new SaslNegotiationFrameReader(); + saslChallenge = new SaslNegotiationFrameWriter(); + requestReader = new DataFrameReader(); + responseWriter = new DataFrameWriter(); + } + + /** + * Get current phase of the state machine. + * + * @return current phase. + */ + public Phase getCurrentPhase() { + return currentPhase; + } + + /** + * Get next phase of the state machine. + * It is different from current phase iff current phase is done (and next phase not yet started). + * + * @return next phase. + */ + public Phase getNextPhase() { + return nextPhase; + } + + /** + * + * @return underlying nonblocking socket + */ + public TNonblockingTransport getUnderlyingTransport() { + return underlyingTransport; + } + + /** + * + * @return SaslServer instance + */ + public SaslServer getSaslServer() { + return saslPeer.getSaslServer(); + } + + /** + * + * @return true if current phase is done. + */ + public boolean isCurrentPhaseDone() { + return currentPhase != nextPhase; + } + + /** + * Run state machine. + * + * @throws IllegalStateException if current state is already done. + */ + public void runCurrentPhase() { + currentPhase.runStateMachine(this); + } + + /** + * When current phase is intrested in read selection, calling this will run the current phase and + * its following phases if the following ones are interested to read, until there is nothing + * available in the underlying transport. + * + * @throws IllegalStateException if is called in an irrelevant phase. + */ + public void handleRead() { + handleOps(INTEREST_READ); + } + + /** + * Similiar to handleRead. But it is for write ops. + * + * @throws IllegalStateException if it is called in an irrelevant phase. + */ + public void handleWrite() { + handleOps(INTEREST_WRITE); + } + + private void handleOps(int interestOps) { + if (currentPhase.selectionInterest != interestOps) { + throw new IllegalStateException("Current phase " + currentPhase + " but got interest " + + interestOps); + } + runCurrentPhase(); + if (isCurrentPhaseDone() && nextPhase.selectionInterest == interestOps) { + stepToNextPhase(); + handleOps(interestOps); + } + } + + /** + * When current phase is finished, it's expected to call this method first before running the + * state machine again. + * By calling this, "next phase" is marked as started (and not done), thus is ready to run. + * + * @throws IllegalArgumentException if current phase is not yet done. + */ + public void stepToNextPhase() { + if (!isCurrentPhaseDone()) { + throw new IllegalArgumentException("Not yet done with current phase: " + currentPhase); + } + LOGGER.debug("Switch phase {} to {}", currentPhase, nextPhase); + switch (nextPhase) { + case INITIIALIIZING: + throw new IllegalStateException("INITIALIZING cannot be the next phase of " + currentPhase); + default: + } + // If next phase's interest is not the same as current, nor the same as the selection key, + // we need to change interest on the selector. + if (!(nextPhase.selectionInterest == currentPhase.selectionInterest || + nextPhase.selectionInterest == selectionKey.interestOps())) { + changeSelectionInterest(nextPhase.selectionInterest); + } + currentPhase = nextPhase; + } + + private void changeSelectionInterest(int selectionInterest) { + selectionKey.interestOps(selectionInterest); + } + + // sasl negotiaion failure handling + private void failSaslNegotiation(TSaslNegotiationException e) { + LOGGER.error("Sasl negotiation failed", e); + String errorMsg = e.getDetails(); + saslChallenge.withHeaderAndPayload(new byte[]{e.getErrorType().code.getValue()}, + errorMsg.getBytes(StandardCharsets.UTF_8)); + nextPhase = Phase.WRITING_FAILURE_MESSAGE; + } + + private void fail(Exception e) { + LOGGER.error("Failed io in " + currentPhase, e); + nextPhase = Phase.CLOSING; + } + + private void failIO(TTransportException e) { + StringBuilder errorMsg = new StringBuilder("IO failure ") + .append(e.getType()) + .append(" in ") + .append(currentPhase); + if (e.getMessage() != null) { + errorMsg.append(": ").append(e.getMessage()); + } + LOGGER.error(errorMsg.toString(), e); + nextPhase = Phase.CLOSING; + } + + // Read handlings + + private void handleInitializing() { + try { + saslResponse.read(underlyingTransport); + if (saslResponse.isComplete()) { + SaslNegotiationHeaderReader startHeader = saslResponse.getHeader(); + if (startHeader.getStatus() != NegotiationStatus.START) { + throw new TInvalidSaslFrameException("Expecting START status but got " + startHeader.getStatus()); + } + String mechanism = new String(saslResponse.getPayload(), StandardCharsets.UTF_8); + saslPeer = saslServerFactory.getSaslPeer(mechanism); + saslResponse.clear(); + nextPhase = Phase.READING_SASL_RESPONSE; + } + } catch (TSaslNegotiationException e) { + failSaslNegotiation(e); + } catch (TTransportException e) { + failIO(e); + } + } + + private void handleReadingSaslResponse() { + try { + saslResponse.read(underlyingTransport); + if (saslResponse.isComplete()) { + nextPhase = Phase.EVALUATING_SASL_RESPONSE; + } + } catch (TSaslNegotiationException e) { + failSaslNegotiation(e); + } catch (TTransportException e) { + failIO(e); + } + } + + private void handleReadingRequest() { + try { + requestReader.read(underlyingTransport); + if (requestReader.isComplete()) { + nextPhase = Phase.PROCESSING; + } + } catch (TTransportException e) { + failIO(e); + } + } + + // Computation executions + + private void executeEvaluatingSaslResponse() { + if (!(saslResponse.getHeader().getStatus() == OK || saslResponse.getHeader().getStatus() == COMPLETE)) { + String error = "Expect status OK or COMPLETE, but got " + saslResponse.getHeader().getStatus(); + failSaslNegotiation(new TSaslNegotiationException(ErrorType.PROTOCOL_ERROR, error)); + return; + } + try { + byte[] response = saslResponse.getPayload(); + saslResponse.clear(); + byte[] newChallenge = saslPeer.evaluate(response); + if (saslPeer.isAuthenticated()) { + dataProtected = saslPeer.isDataProtected(); + saslChallenge.withHeaderAndPayload(new byte[]{COMPLETE.getValue()}, newChallenge); + nextPhase = Phase.WRITING_SUCCESS_MESSAGE; + } else { + saslChallenge.withHeaderAndPayload(new byte[]{OK.getValue()}, newChallenge); + nextPhase = Phase.WRITING_SASL_CHALLENGE; + } + } catch (TSaslNegotiationException e) { + failSaslNegotiation(e); + } + } + + private void executeProcessing() { + try { + byte[] inputPayload = requestReader.getPayload(); + requestReader.clear(); + byte[] rawInput = dataProtected ? saslPeer.unwrap(inputPayload) : inputPayload; + TMemoryTransport memoryTransport = new TMemoryTransport(rawInput); + TProtocol requestProtocol = inputProtocolFactory.getProtocol(memoryTransport); + TProtocol responseProtocol = outputProtocolFactory.getProtocol(memoryTransport); + + if (eventHandler != null) { + if (!serverContextCreated) { + serverContext = eventHandler.createContext(requestProtocol, responseProtocol); + serverContextCreated = true; + } + eventHandler.processContext(serverContext, memoryTransport, memoryTransport); + } + + TProcessor processor = processorFactory.getProcessor(this); + processor.process(requestProtocol, responseProtocol); + TByteArrayOutputStream rawOutput = memoryTransport.getOutput(); + if (rawOutput.len() == 0) { + // This is a oneway request, no response to send back. Waiting for next incoming request. + nextPhase = Phase.READING_REQUEST; + return; + } + if (dataProtected) { + byte[] outputPayload = saslPeer.wrap(rawOutput.get(), 0, rawOutput.len()); + responseWriter.withOnlyPayload(outputPayload); + } else { + responseWriter.withOnlyPayload(rawOutput.get(), 0 ,rawOutput.len()); + } + nextPhase = Phase.WRITING_RESPONSE; + } catch (TTransportException e) { + failIO(e); + } catch (Exception e) { + fail(e); + } + } + + // Write handlings + + private void handleWritingSaslChallenge() { + try { + saslChallenge.write(underlyingTransport); + if (saslChallenge.isComplete()) { + saslChallenge.clear(); + nextPhase = Phase.READING_SASL_RESPONSE; + } + } catch (TTransportException e) { + fail(e); + } + } + + private void handleWritingSuccessMessage() { + try { + saslChallenge.write(underlyingTransport); + if (saslChallenge.isComplete()) { + LOGGER.debug("Authentication is done."); + saslChallenge = null; + saslResponse = null; + nextPhase = Phase.READING_REQUEST; + } + } catch (TTransportException e) { + fail(e); + } + } + + private void handleWritingFailureMessage() { + try { + saslChallenge.write(underlyingTransport); + if (saslChallenge.isComplete()) { + nextPhase = Phase.CLOSING; + } + } catch (TTransportException e) { + fail(e); + } + } + + private void handleWritingResponse() { + try { + responseWriter.write(underlyingTransport); + if (responseWriter.isComplete()) { + responseWriter.clear(); + nextPhase = Phase.READING_REQUEST; + } + } catch (TTransportException e) { + fail(e); + } + } + + /** + * Release all the resources managed by this state machine (connection, selection and sasl server). + * To avoid being blocked, this should be invoked in the network thread that manages the selector. + */ + public void close() { + underlyingTransport.close(); + selectionKey.cancel(); + if (saslPeer != null) { + saslPeer.dispose(); + } + if (serverContextCreated) { + eventHandler.deleteContext(serverContext, + inputProtocolFactory.getProtocol(underlyingTransport), + outputProtocolFactory.getProtocol(underlyingTransport)); + } + nextPhase = Phase.CLOSED; + currentPhase = Phase.CLOSED; + LOGGER.trace("Connection closed: {}", underlyingTransport); + } + + public enum Phase { + INITIIALIIZING(INTEREST_READ) { + @Override + void unsafeRun(NonblockingSaslHandler statemachine) { + statemachine.handleInitializing(); + } + }, + READING_SASL_RESPONSE(INTEREST_READ) { + @Override + void unsafeRun(NonblockingSaslHandler statemachine) { + statemachine.handleReadingSaslResponse(); + } + }, + EVALUATING_SASL_RESPONSE(INTEREST_NONE) { + @Override + void unsafeRun(NonblockingSaslHandler statemachine) { + statemachine.executeEvaluatingSaslResponse(); + } + }, + WRITING_SASL_CHALLENGE(INTEREST_WRITE) { + @Override + void unsafeRun(NonblockingSaslHandler statemachine) { + statemachine.handleWritingSaslChallenge(); + } + }, + WRITING_SUCCESS_MESSAGE(INTEREST_WRITE) { + @Override + void unsafeRun(NonblockingSaslHandler statemachine) { + statemachine.handleWritingSuccessMessage(); + } + }, + WRITING_FAILURE_MESSAGE(INTEREST_WRITE) { + @Override + void unsafeRun(NonblockingSaslHandler statemachine) { + statemachine.handleWritingFailureMessage(); + } + }, + READING_REQUEST(INTEREST_READ) { + @Override + void unsafeRun(NonblockingSaslHandler statemachine) { + statemachine.handleReadingRequest(); + } + }, + PROCESSING(INTEREST_NONE) { + @Override + void unsafeRun(NonblockingSaslHandler statemachine) { + statemachine.executeProcessing(); + } + }, + WRITING_RESPONSE(INTEREST_WRITE) { + @Override + void unsafeRun(NonblockingSaslHandler statemachine) { + statemachine.handleWritingResponse(); + } + }, + CLOSING(INTEREST_NONE) { + @Override + void unsafeRun(NonblockingSaslHandler statemachine) { + statemachine.close(); + } + }, + CLOSED(INTEREST_NONE) { + @Override + void unsafeRun(NonblockingSaslHandler statemachine) { + // Do nothing. + } + } + ; + + // The interest on the selection key during the phase + private int selectionInterest; + + Phase(int selectionInterest) { + this.selectionInterest = selectionInterest; + } + + /** + * Provide the execution to run for the state machine in current phase. The execution should + * return the next phase after running on the state machine. + * + * @param statemachine The state machine to run. + * @throws IllegalArgumentException if the state machine's current phase is different. + * @throws IllegalStateException if the state machine' current phase is already done. + */ + void runStateMachine(NonblockingSaslHandler statemachine) { + if (statemachine.currentPhase != this) { + throw new IllegalArgumentException("State machine is " + statemachine.currentPhase + + " but is expected to be " + this); + } + if (statemachine.isCurrentPhaseDone()) { + throw new IllegalStateException("State machine should step into " + statemachine.nextPhase); + } + unsafeRun(statemachine); + } + + // Run the state machine without checkiing its own phase + // It should not be called direcly by users. + abstract void unsafeRun(NonblockingSaslHandler statemachine); + } +} diff --git a/lib/java/src/main/java/org/apache/thrift/transport/sasl/SaslNegotiationFrameReader.java b/lib/java/src/main/java/org/apache/thrift/transport/sasl/SaslNegotiationFrameReader.java new file mode 100644 index 000000000..01c172836 --- /dev/null +++ b/lib/java/src/main/java/org/apache/thrift/transport/sasl/SaslNegotiationFrameReader.java @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.transport.sasl; + +/** + * Read frames for sasl negotiatiions. + */ +public class SaslNegotiationFrameReader extends FrameReader<SaslNegotiationHeaderReader> { + + public SaslNegotiationFrameReader() { + super(new SaslNegotiationHeaderReader()); + } +} diff --git a/lib/java/src/main/java/org/apache/thrift/transport/sasl/SaslNegotiationFrameWriter.java b/lib/java/src/main/java/org/apache/thrift/transport/sasl/SaslNegotiationFrameWriter.java new file mode 100644 index 000000000..1e9ad1570 --- /dev/null +++ b/lib/java/src/main/java/org/apache/thrift/transport/sasl/SaslNegotiationFrameWriter.java @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.transport.sasl; + +import java.nio.ByteBuffer; + +import org.apache.thrift.EncodingUtils; +import org.apache.thrift.utils.StringUtils; + +import static org.apache.thrift.transport.sasl.SaslNegotiationHeaderReader.PAYLOAD_LENGTH_BYTES; +import static org.apache.thrift.transport.sasl.SaslNegotiationHeaderReader.STATUS_BYTES; + +/** + * Writer for sasl negotiation frames. It expect a status byte as header with a payload to be + * written out (any header whose size is not equal to 1 would be considered as error). + */ +public class SaslNegotiationFrameWriter extends FrameWriter { + + public static final int HEADER_BYTES = STATUS_BYTES + PAYLOAD_LENGTH_BYTES; + + @Override + public void withOnlyPayload(byte[] payload, int offset, int length) { + throw new UnsupportedOperationException("Status byte is expected for sasl frame header."); + } + + @Override + protected ByteBuffer buildFrame(byte[] header, int headerOffset, int headerLength, + byte[] payload, int payloadOffset, int payloadLength) { + if (header == null || headerLength != STATUS_BYTES) { + throw new IllegalArgumentException("Header " + StringUtils.bytesToHexString(header) + + " does not have expected length " + STATUS_BYTES); + } + byte[] bytes = new byte[HEADER_BYTES + payloadLength]; + System.arraycopy(header, headerOffset, bytes, 0, STATUS_BYTES); + EncodingUtils.encodeBigEndian(payloadLength, bytes, STATUS_BYTES); + System.arraycopy(payload, payloadOffset, bytes, HEADER_BYTES, payloadLength); + return ByteBuffer.wrap(bytes); + } +} diff --git a/lib/java/src/main/java/org/apache/thrift/transport/sasl/SaslNegotiationHeaderReader.java b/lib/java/src/main/java/org/apache/thrift/transport/sasl/SaslNegotiationHeaderReader.java new file mode 100644 index 000000000..2d76ddb29 --- /dev/null +++ b/lib/java/src/main/java/org/apache/thrift/transport/sasl/SaslNegotiationHeaderReader.java @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.transport.sasl; + +import static org.apache.thrift.transport.sasl.TSaslNegotiationException.ErrorType.PROTOCOL_ERROR; + +/** + * Header for sasl negotiation frames. It contains status byte of negotiation and a 4-byte integer + * (payload size). + */ +public class SaslNegotiationHeaderReader extends FixedSizeHeaderReader { + public static final int STATUS_BYTES = 1; + public static final int PAYLOAD_LENGTH_BYTES = 4; + + private NegotiationStatus negotiationStatus; + private int payloadSize; + + @Override + protected int headerSize() { + return STATUS_BYTES + PAYLOAD_LENGTH_BYTES; + } + + @Override + protected void onComplete() throws TSaslNegotiationException { + negotiationStatus = NegotiationStatus.byValue(byteBuffer.get(0)); + payloadSize = byteBuffer.getInt(1); + if (payloadSize < 0) { + throw new TSaslNegotiationException(PROTOCOL_ERROR, "Payload size is negative: " + payloadSize); + } + } + + @Override + public int payloadSize() { + return payloadSize; + } + + public NegotiationStatus getStatus() { + return negotiationStatus; + } +} diff --git a/lib/java/src/main/java/org/apache/thrift/transport/sasl/SaslPeer.java b/lib/java/src/main/java/org/apache/thrift/transport/sasl/SaslPeer.java new file mode 100644 index 000000000..8f8138044 --- /dev/null +++ b/lib/java/src/main/java/org/apache/thrift/transport/sasl/SaslPeer.java @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.transport.sasl; + +import org.apache.thrift.transport.TTransportException; + +/** + * A peer in a sasl negotiation. + */ +public interface SaslPeer { + + /** + * Evaluate and validate the negotiation message (response/challenge) received from peer. + * + * @param negotiationMessage response/challenge received from peer. + * @return new response/challenge to send to peer, can be null if authentication becomes success. + * @throws TSaslNegotiationException if sasl authentication fails. + */ + byte[] evaluate(byte[] negotiationMessage) throws TSaslNegotiationException; + + /** + * @return true if authentication is done. + */ + boolean isAuthenticated(); + + /** + * This method can only be called when the negotiation is complete (isAuthenticated returns true). + * Otherwise it will throw IllegalStateExceptiion. + * + * @return if the qop requires some integrity/confidential protection. + * @throws IllegalStateException if negotiation is not yet complete. + */ + boolean isDataProtected(); + + /** + * Wrap raw bytes to protect it. + * + * @param data raw bytes. + * @param offset the start position of the content to wrap. + * @param length the length of the content to wrap. + * @return bytes with protection to send to peer. + * @throws TTransportException if failure. + */ + byte[] wrap(byte[] data, int offset, int length) throws TTransportException; + + /** + * Wrap the whole byte array. + * + * @param data raw bytes. + * @return wrapped bytes. + * @throws TTransportException if failure. + */ + default byte[] wrap(byte[] data) throws TTransportException { + return wrap(data, 0, data.length); + } + + /** + * Unwrap protected data to raw bytes. + * + * @param data protected data received from peer. + * @param offset the start position of the content to unwrap. + * @param length the length of the content to unwrap. + * @return raw bytes. + * @throws TTransportException if failed. + */ + byte[] unwrap(byte[] data, int offset, int length) throws TTransportException; + + /** + * Unwrap the whole byte array. + * + * @param data wrapped bytes. + * @return raw bytes. + * @throws TTransportException if failure. + */ + default byte[] unwrap(byte[] data) throws TTransportException { + return unwrap(data, 0, data.length); + } + + /** + * Close this peer and release resources. + */ + void dispose(); +} diff --git a/lib/java/src/main/java/org/apache/thrift/transport/sasl/ServerSaslPeer.java b/lib/java/src/main/java/org/apache/thrift/transport/sasl/ServerSaslPeer.java new file mode 100644 index 000000000..31992e5fc --- /dev/null +++ b/lib/java/src/main/java/org/apache/thrift/transport/sasl/ServerSaslPeer.java @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.transport.sasl; + +import javax.security.sasl.Sasl; +import javax.security.sasl.SaslException; +import javax.security.sasl.SaslServer; + +import org.apache.thrift.transport.TTransportException; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import static org.apache.thrift.transport.sasl.TSaslNegotiationException.ErrorType.AUTHENTICATION_FAILURE; + +/** + * Server side sasl peer, a wrapper around SaslServer to provide some handy methods. + */ +public class ServerSaslPeer implements SaslPeer { + private static final Logger LOGGER = LoggerFactory.getLogger(ServerSaslPeer.class); + + private static final String QOP_AUTH_INT = "auth-int"; + private static final String QOP_AUTH_CONF = "auth-conf"; + + private final SaslServer saslServer; + + public ServerSaslPeer(SaslServer saslServer) { + this.saslServer = saslServer; + } + + @Override + public byte[] evaluate(byte[] negotiationMessage) throws TSaslNegotiationException { + try { + return saslServer.evaluateResponse(negotiationMessage); + } catch (SaslException e) { + throw new TSaslNegotiationException(AUTHENTICATION_FAILURE, + "Authentication failed with " + saslServer.getMechanismName(), e); + } + } + + @Override + public boolean isAuthenticated() { + return saslServer.isComplete(); + } + + @Override + public boolean isDataProtected() { + Object qop = saslServer.getNegotiatedProperty(Sasl.QOP); + if (qop == null) { + return false; + } + for (String word : qop.toString().split("\\s*,\\s*")) { + String lowerCaseWord = word.toLowerCase(); + if (QOP_AUTH_INT.equals(lowerCaseWord) || QOP_AUTH_CONF.equals(lowerCaseWord)) { + return true; + } + } + return false; + } + + @Override + public byte[] wrap(byte[] data, int offset, int length) throws TTransportException { + try { + return saslServer.wrap(data, offset, length); + } catch (SaslException e) { + throw new TTransportException("Failed to wrap data", e); + } + } + + @Override + public byte[] unwrap(byte[] data, int offset, int length) throws TTransportException { + try { + return saslServer.unwrap(data, offset, length); + } catch (SaslException e) { + throw new TTransportException(TTransportException.CORRUPTED_DATA, "Failed to unwrap data", e); + } + } + + @Override + public void dispose() { + try { + saslServer.dispose(); + } catch (Exception e) { + LOGGER.warn("Failed to close sasl server " + saslServer.getMechanismName(), e); + } + } + + SaslServer getSaslServer() { + return saslServer; + } + +} diff --git a/lib/java/src/main/java/org/apache/thrift/transport/sasl/TBaseSaslProcessorFactory.java b/lib/java/src/main/java/org/apache/thrift/transport/sasl/TBaseSaslProcessorFactory.java new file mode 100644 index 000000000..c08884c22 --- /dev/null +++ b/lib/java/src/main/java/org/apache/thrift/transport/sasl/TBaseSaslProcessorFactory.java @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.transport.sasl; + +import org.apache.thrift.TProcessor; + +public class TBaseSaslProcessorFactory implements TSaslProcessorFactory { + + private final TProcessor processor; + + public TBaseSaslProcessorFactory(TProcessor processor) { + this.processor = processor; + } + + @Override + public TProcessor getProcessor(NonblockingSaslHandler saslHandler) { + return processor; + } +} diff --git a/lib/java/src/main/java/org/apache/thrift/transport/sasl/TInvalidSaslFrameException.java b/lib/java/src/main/java/org/apache/thrift/transport/sasl/TInvalidSaslFrameException.java new file mode 100644 index 000000000..ff57ea5c4 --- /dev/null +++ b/lib/java/src/main/java/org/apache/thrift/transport/sasl/TInvalidSaslFrameException.java @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.transport.sasl; + +/** + * Got an invalid frame that does not respect the thrift sasl protocol. + */ +public class TInvalidSaslFrameException extends TSaslNegotiationException { + + public TInvalidSaslFrameException(String message) { + super(ErrorType.PROTOCOL_ERROR, message); + } +} diff --git a/lib/java/src/main/java/org/apache/thrift/transport/sasl/TSaslNegotiationException.java b/lib/java/src/main/java/org/apache/thrift/transport/sasl/TSaslNegotiationException.java new file mode 100644 index 000000000..9b1fa060e --- /dev/null +++ b/lib/java/src/main/java/org/apache/thrift/transport/sasl/TSaslNegotiationException.java @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.transport.sasl; + +import org.apache.thrift.transport.TTransportException; + +/** + * Exception for sasl negotiation errors. + */ +public class TSaslNegotiationException extends TTransportException { + + private final ErrorType error; + + public TSaslNegotiationException(ErrorType error, String summary) { + super(summary); + this.error = error; + } + + public TSaslNegotiationException(ErrorType error, String summary, Throwable cause) { + super(summary, cause); + this.error = error; + } + + public ErrorType getErrorType() { + return error; + } + + /** + * @return Errory type plus the message. + */ + public String getSummary() { + return error.name() + ": " + getMessage(); + } + + /** + * @return Summary and eventually the cause's message. + */ + public String getDetails() { + return getCause() == null ? getSummary() : getSummary() + "\nReason: " + getCause().getMessage(); + } + + public enum ErrorType { + // Unexpected system internal error during negotiation (e.g. sasl initialization failure) + INTERNAL_ERROR(NegotiationStatus.ERROR), + // Cannot read correct sasl frames from the connection => Send "ERROR" status byte to peer + PROTOCOL_ERROR(NegotiationStatus.ERROR), + // Peer is using unsupported sasl mechanisms => Send "BAD" status byte to peer + MECHANISME_MISMATCH(NegotiationStatus.BAD), + // Sasl authentication failure => Send "BAD" status byte to peer + AUTHENTICATION_FAILURE(NegotiationStatus.BAD), + ; + + public final NegotiationStatus code; + + ErrorType(NegotiationStatus code) { + this.code = code; + } + } +} diff --git a/lib/java/src/main/java/org/apache/thrift/transport/sasl/TSaslProcessorFactory.java b/lib/java/src/main/java/org/apache/thrift/transport/sasl/TSaslProcessorFactory.java new file mode 100644 index 000000000..877d0496f --- /dev/null +++ b/lib/java/src/main/java/org/apache/thrift/transport/sasl/TSaslProcessorFactory.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.transport.sasl; + +import org.apache.thrift.TException; +import org.apache.thrift.TProcessor; + +/** + * Get processor for a given state machine, so that users can customize the behavior of a TProcessor + * by interacting with the state machine. + */ +public interface TSaslProcessorFactory { + + TProcessor getProcessor(NonblockingSaslHandler saslHandler) throws TException; +} diff --git a/lib/java/src/main/java/org/apache/thrift/transport/sasl/TSaslServerDefinition.java b/lib/java/src/main/java/org/apache/thrift/transport/sasl/TSaslServerDefinition.java new file mode 100644 index 000000000..5486641d8 --- /dev/null +++ b/lib/java/src/main/java/org/apache/thrift/transport/sasl/TSaslServerDefinition.java @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.transport.sasl; + +import javax.security.auth.callback.CallbackHandler; +import java.util.Map; + +/** + * Contains all the parameters used to define a SASL server implementation. + */ +public class TSaslServerDefinition { + public final String mechanism; + public final String protocol; + public final String serverName; + public final Map<String, String> props; + public final CallbackHandler cbh; + + public TSaslServerDefinition(String mechanism, String protocol, String serverName, + Map<String, String> props, CallbackHandler cbh) { + this.mechanism = mechanism; + this.protocol = protocol; + this.serverName = serverName; + this.props = props; + this.cbh = cbh; + } +} diff --git a/lib/java/src/main/java/org/apache/thrift/transport/sasl/TSaslServerFactory.java b/lib/java/src/main/java/org/apache/thrift/transport/sasl/TSaslServerFactory.java new file mode 100644 index 000000000..06cf534b6 --- /dev/null +++ b/lib/java/src/main/java/org/apache/thrift/transport/sasl/TSaslServerFactory.java @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.transport.sasl; + +import java.util.HashMap; +import java.util.Map; + +import javax.security.auth.callback.CallbackHandler; +import javax.security.sasl.Sasl; +import javax.security.sasl.SaslException; +import javax.security.sasl.SaslServer; + +import static org.apache.thrift.transport.sasl.TSaslNegotiationException.ErrorType.MECHANISME_MISMATCH; +import static org.apache.thrift.transport.sasl.TSaslNegotiationException.ErrorType.PROTOCOL_ERROR; + +/** + * Factory to create sasl server. Users can extend this class to customize the SaslServer creation. + */ +public class TSaslServerFactory { + + private final Map<String, TSaslServerDefinition> saslMechanisms; + + public TSaslServerFactory() { + this.saslMechanisms = new HashMap<>(); + } + + public void addSaslMechanism(String mechanism, String protocol, String serverName, + Map<String, String> props, CallbackHandler cbh) { + TSaslServerDefinition definition = new TSaslServerDefinition(mechanism, protocol, serverName, + props, cbh); + saslMechanisms.put(definition.mechanism, definition); + } + + public ServerSaslPeer getSaslPeer(String mechanism) throws TSaslNegotiationException { + if (!saslMechanisms.containsKey(mechanism)) { + throw new TSaslNegotiationException(MECHANISME_MISMATCH, "Unsupported mechanism " + mechanism); + } + TSaslServerDefinition saslDef = saslMechanisms.get(mechanism); + try { + SaslServer saslServer = Sasl.createSaslServer(saslDef.mechanism, saslDef.protocol, + saslDef.serverName, saslDef.props, saslDef.cbh); + return new ServerSaslPeer(saslServer); + } catch (SaslException e) { + throw new TSaslNegotiationException(PROTOCOL_ERROR, "Fail to create sasl server " + mechanism, e); + } + } +} |