diff options
author | Nobuaki Sukegawa <nsuke@apache.org> | 2016-01-24 04:01:27 +0900 |
---|---|---|
committer | Jens Geyer <jensg@apache.org> | 2017-12-03 17:45:33 +0100 |
commit | fc0ff81ee7d4aa95a041c826dd5a83239ef98780 (patch) | |
tree | 51cca76f2a8f1d23de9f09d585f9fa6923ae51ee | |
parent | 1310dc1eb4014457a667a7287d1fa113432c7a54 (diff) | |
download | thrift-fc0ff81ee7d4aa95a041c826dd5a83239ef98780.tar.gz |
THRIFT-3580 THeader for Haskell
Client: hs
This closes #820
This closes #1423
-rw-r--r-- | compiler/cpp/src/thrift/generate/t_hs_generator.cc | 64 | ||||
-rw-r--r-- | lib/hs/src/Thrift.hs | 4 | ||||
-rw-r--r-- | lib/hs/src/Thrift/Protocol.hs | 41 | ||||
-rw-r--r-- | lib/hs/src/Thrift/Protocol/Binary.hs | 59 | ||||
-rw-r--r-- | lib/hs/src/Thrift/Protocol/Compact.hs | 62 | ||||
-rw-r--r-- | lib/hs/src/Thrift/Protocol/Header.hs | 141 | ||||
-rw-r--r-- | lib/hs/src/Thrift/Protocol/JSON.hs | 58 | ||||
-rw-r--r-- | lib/hs/src/Thrift/Server.hs | 6 | ||||
-rw-r--r-- | lib/hs/src/Thrift/Transport/Handle.hs | 14 | ||||
-rw-r--r-- | lib/hs/src/Thrift/Transport/Header.hs | 354 | ||||
-rw-r--r-- | lib/hs/thrift.cabal | 2 | ||||
-rw-r--r-- | test/hs/TestClient.hs | 6 | ||||
-rw-r--r-- | test/hs/TestServer.hs | 15 | ||||
-rw-r--r-- | test/known_failures_Linux.json | 4 | ||||
-rw-r--r-- | test/tests.json | 1 |
15 files changed, 690 insertions, 141 deletions
diff --git a/compiler/cpp/src/thrift/generate/t_hs_generator.cc b/compiler/cpp/src/thrift/generate/t_hs_generator.cc index 30eb8fa9a..d0a8cb2d6 100644 --- a/compiler/cpp/src/thrift/generate/t_hs_generator.cc +++ b/compiler/cpp/src/thrift/generate/t_hs_generator.cc @@ -711,13 +711,13 @@ void t_hs_generator::generate_hs_struct_reader(ofstream& out, t_struct* tstruct) string tmap = type_name(tstruct, "typemap_"); indent(out) << "to_" << sname << " _ = P.error \"not a struct\"" << endl; - indent(out) << "read_" << sname << " :: (T.Transport t, T.Protocol p) => p t -> P.IO " << sname + indent(out) << "read_" << sname << " :: T.Protocol p => p -> P.IO " << sname << endl; indent(out) << "read_" << sname << " iprot = to_" << sname; out << " <$> T.readVal iprot (T.T_STRUCT " << tmap << ")" << endl; indent(out) << "decode_" << sname - << " :: (T.Protocol p, T.Transport t) => p t -> LBS.ByteString -> " << sname << endl; + << " :: T.StatelessProtocol p => p -> LBS.ByteString -> " << sname << endl; indent(out) << "decode_" << sname << " iprot bs = to_" << sname << " $ "; out << "T.deserializeVal iprot (T.T_STRUCT " << tmap << ") bs" << endl; } @@ -818,13 +818,13 @@ void t_hs_generator::generate_hs_struct_writer(ofstream& out, t_struct* tstruct) indent_down(); // write - indent(out) << "write_" << name << " :: (T.Protocol p, T.Transport t) => p t -> " << name + indent(out) << "write_" << name << " :: T.Protocol p => p -> " << name << " -> P.IO ()" << endl; indent(out) << "write_" << name << " oprot record = T.writeVal oprot $ from_"; out << name << " record" << endl; // encode - indent(out) << "encode_" << name << " :: (T.Protocol p, T.Transport t) => p t -> " << name + indent(out) << "encode_" << name << " :: T.StatelessProtocol p => p -> " << name << " -> LBS.ByteString" << endl; indent(out) << "encode_" << name << " oprot record = T.serializeVal oprot $ "; out << "from_" << name << " record" << endl; @@ -1085,8 +1085,9 @@ void t_hs_generator::generate_service_client(t_service* tservice) { // Serialize the request header string fname = (*f_iter)->get_name(); string msgType = (*f_iter)->is_oneway() ? "T.M_ONEWAY" : "T.M_CALL"; - indent(f_client_) << "T.writeMessageBegin op (\"" << fname << "\", " << msgType << ", seqn)" + indent(f_client_) << "T.writeMessage op (\"" << fname << "\", " << msgType << ", seqn) $" << endl; + indent_up(); indent(f_client_) << "write_" << argsname << " op (" << argsname << "{"; bool first = true; @@ -1102,10 +1103,7 @@ void t_hs_generator::generate_service_client(t_service* tservice) { first = false; } f_client_ << "})" << endl; - indent(f_client_) << "T.writeMessageEnd op" << endl; - - // Write to the stream - indent(f_client_) << "T.tFlush (T.getTransport op)" << endl; + indent_down(); indent_down(); if (!(*f_iter)->is_oneway()) { @@ -1119,12 +1117,12 @@ void t_hs_generator::generate_service_client(t_service* tservice) { indent(f_client_) << funname << " ip = do" << endl; indent_up(); - indent(f_client_) << "(fname, mtype, rseqid) <- T.readMessageBegin ip" << endl; + indent(f_client_) << "T.readMessage ip $ \\(fname, mtype, rseqid) -> do" << endl; + indent_up(); indent(f_client_) << "M.when (mtype == T.M_EXCEPTION) $ do { exn <- T.readAppExn ip ; " - "T.readMessageEnd ip ; X.throw exn }" << endl; + "X.throw exn }" << endl; indent(f_client_) << "res <- read_" << resultname << " ip" << endl; - indent(f_client_) << "T.readMessageEnd ip" << endl; t_struct* xs = (*f_iter)->get_xceptions(); const vector<t_field*>& xceptions = xs->get_members(); @@ -1142,6 +1140,7 @@ void t_hs_generator::generate_service_client(t_service* tservice) { // Close function indent_down(); + indent_down(); } } @@ -1180,11 +1179,11 @@ void t_hs_generator::generate_service_server(t_service* tservice) { f_service_ << "do" << endl; indent_up(); indent(f_service_) << "_ <- T.readVal iprot (T.T_STRUCT Map.empty)" << endl; - indent(f_service_) << "T.writeMessageBegin oprot (name,T.M_EXCEPTION,seqid)" << endl; + indent(f_service_) << "T.writeMessage oprot (name,T.M_EXCEPTION,seqid) $" << endl; + indent_up(); indent(f_service_) << "T.writeAppExn oprot (T.AppExn T.AE_UNKNOWN_METHOD (\"Unknown function " "\" ++ LT.unpack name))" << endl; - indent(f_service_) << "T.writeMessageEnd oprot" << endl; - indent(f_service_) << "T.tFlush (T.getTransport oprot)" << endl; + indent_down(); indent_down(); } @@ -1194,9 +1193,8 @@ void t_hs_generator::generate_service_server(t_service* tservice) { indent(f_service_) << "process handler (iprot, oprot) = do" << endl; indent_up(); - indent(f_service_) << "(name, typ, seqid) <- T.readMessageBegin iprot" << endl; - indent(f_service_) << "proc_ handler (iprot,oprot) (name,typ,seqid)" << endl; - indent(f_service_) << "T.readMessageEnd iprot" << endl; + indent(f_service_) << "T.readMessage iprot (" << endl; + indent(f_service_) << " proc_ handler (iprot,oprot))" << endl; indent(f_service_) << "P.return P.True" << endl; indent_down(); } @@ -1286,11 +1284,11 @@ void t_hs_generator::generate_process_function(t_service* tservice, t_function* if (tfunction->is_oneway()) { indent(f_service_) << "P.return ()"; } else { - indent(f_service_) << "T.writeMessageBegin oprot (\"" << tfunction->get_name() - << "\", T.M_REPLY, seqid)" << endl; - indent(f_service_) << "write_" << resultname << " oprot res" << endl; - indent(f_service_) << "T.writeMessageEnd oprot" << endl; - indent(f_service_) << "T.tFlush (T.getTransport oprot)"; + indent(f_service_) << "T.writeMessage oprot (\"" << tfunction->get_name() + << "\", T.M_REPLY, seqid) $" << endl; + indent_up(); + indent(f_service_) << "write_" << resultname << " oprot res"; + indent_down(); } if (n > 0) { f_service_ << ")"; @@ -1307,11 +1305,11 @@ void t_hs_generator::generate_process_function(t_service* tservice, t_function* indent(f_service_) << "let res = default_" << resultname << "{" << field_name(resultname, (*x_iter)->get_name()) << " = P.Just e}" << endl; - indent(f_service_) << "T.writeMessageBegin oprot (\"" << tfunction->get_name() - << "\", T.M_REPLY, seqid)" << endl; - indent(f_service_) << "write_" << resultname << " oprot res" << endl; - indent(f_service_) << "T.writeMessageEnd oprot" << endl; - indent(f_service_) << "T.tFlush (T.getTransport oprot)"; + indent(f_service_) << "T.writeMessage oprot (\"" << tfunction->get_name() + << "\", T.M_REPLY, seqid) $" << endl; + indent_up(); + indent(f_service_) << "write_" << resultname << " oprot res"; + indent_down(); } else { indent(f_service_) << "P.return ()"; } @@ -1324,11 +1322,11 @@ void t_hs_generator::generate_process_function(t_service* tservice, t_function* indent_up(); if (!tfunction->is_oneway()) { - indent(f_service_) << "T.writeMessageBegin oprot (\"" << tfunction->get_name() - << "\", T.M_EXCEPTION, seqid)" << endl; - indent(f_service_) << "T.writeAppExn oprot (T.AppExn T.AE_UNKNOWN \"\")" << endl; - indent(f_service_) << "T.writeMessageEnd oprot" << endl; - indent(f_service_) << "T.tFlush (T.getTransport oprot)"; + indent(f_service_) << "T.writeMessage oprot (\"" << tfunction->get_name() + << "\", T.M_EXCEPTION, seqid) $" << endl; + indent_up(); + indent(f_service_) << "T.writeAppExn oprot (T.AppExn T.AE_UNKNOWN \"\")"; + indent_down(); } else { indent(f_service_) << "P.return ()"; } diff --git a/lib/hs/src/Thrift.hs b/lib/hs/src/Thrift.hs index 58a304b6e..658020991 100644 --- a/lib/hs/src/Thrift.hs +++ b/lib/hs/src/Thrift.hs @@ -90,13 +90,13 @@ data AppExn = AppExn { ae_type :: AppExnType, ae_message :: String } deriving ( Show, Typeable ) instance Exception AppExn -writeAppExn :: (Protocol p, Transport t) => p t -> AppExn -> IO () +writeAppExn :: Protocol p => p -> AppExn -> IO () writeAppExn pt ae = writeVal pt $ TStruct $ Map.fromList [ (1, ("message", TString $ encodeUtf8 $ pack $ ae_message ae)) , (2, ("type", TI32 $ fromIntegral $ fromEnum (ae_type ae))) ] -readAppExn :: (Protocol p, Transport t) => p t -> IO AppExn +readAppExn :: Protocol p => p -> IO AppExn readAppExn pt = do let typemap = Map.fromList [(1,("message",T_STRING)),(2,("type",T_I32))] TStruct fields <- readVal pt $ T_STRUCT typemap diff --git a/lib/hs/src/Thrift/Protocol.hs b/lib/hs/src/Thrift/Protocol.hs index ed779a27d..67a9175cb 100644 --- a/lib/hs/src/Thrift/Protocol.hs +++ b/lib/hs/src/Thrift/Protocol.hs @@ -22,12 +22,11 @@ module Thrift.Protocol ( Protocol(..) + , StatelessProtocol(..) , ProtocolExn(..) , ProtocolExnType(..) , getTypeOf , runParser - , versionMask - , version1 , bsToDouble , bsToDoubleLE ) where @@ -35,7 +34,6 @@ module Thrift.Protocol import Control.Exception import Data.Attoparsec.ByteString import Data.Bits -import Data.ByteString.Lazy (ByteString, toStrict) import Data.ByteString.Unsafe import Data.Functor ((<$>)) import Data.Int @@ -44,37 +42,26 @@ import Data.Text.Lazy (Text) import Data.Typeable (Typeable) import Data.Word import Foreign.Ptr (castPtr) -import Foreign.Storable (Storable, peek, poke) +import Foreign.Storable (peek, poke) import System.IO.Unsafe import qualified Data.ByteString as BS import qualified Data.HashMap.Strict as Map +import qualified Data.ByteString.Lazy as LBS -import Thrift.Types import Thrift.Transport - -versionMask :: Int32 -versionMask = fromIntegral (0xffff0000 :: Word32) - -version1 :: Int32 -version1 = fromIntegral (0x80010000 :: Word32) +import Thrift.Types class Protocol a where - getTransport :: Transport t => a t -> t - - writeMessageBegin :: Transport t => a t -> (Text, MessageType, Int32) -> IO () - writeMessageEnd :: Transport t => a t -> IO () - writeMessageEnd _ = return () - - readMessageBegin :: Transport t => a t -> IO (Text, MessageType, Int32) - readMessageEnd :: Transport t => a t -> IO () - readMessageEnd _ = return () + readByte :: a -> IO LBS.ByteString + readVal :: a -> ThriftType -> IO ThriftVal + readMessage :: a -> ((Text, MessageType, Int32) -> IO b) -> IO b - serializeVal :: Transport t => a t -> ThriftVal -> ByteString - deserializeVal :: Transport t => a t -> ThriftType -> ByteString -> ThriftVal + writeVal :: a -> ThriftVal -> IO () + writeMessage :: a -> (Text, MessageType, Int32) -> IO () -> IO () - writeVal :: Transport t => a t -> ThriftVal -> IO () - writeVal p = tWrite (getTransport p) . serializeVal p - readVal :: Transport t => a t -> ThriftType -> IO ThriftVal +class Protocol a => StatelessProtocol a where + serializeVal :: a -> ThriftVal -> LBS.ByteString + deserializeVal :: a -> ThriftType -> LBS.ByteString -> ThriftVal data ProtocolExnType = PE_UNKNOWN @@ -105,10 +92,10 @@ getTypeOf v = case v of TBinary{} -> T_BINARY TDouble{} -> T_DOUBLE -runParser :: (Protocol p, Transport t, Show a) => p t -> Parser a -> IO a +runParser :: (Protocol p, Show a) => p -> Parser a -> IO a runParser prot p = refill >>= getResult . parse p where - refill = handle handleEOF $ toStrict <$> tReadAll (getTransport prot) 1 + refill = handle handleEOF $ LBS.toStrict <$> readByte prot getResult (Done _ a) = return a getResult (Partial k) = refill >>= getResult . k getResult f = throw $ ProtocolExn PE_INVALID_DATA (show f) diff --git a/lib/hs/src/Thrift/Protocol/Binary.hs b/lib/hs/src/Thrift/Protocol/Binary.hs index 2d35305dc..7b0acd9d4 100644 --- a/lib/hs/src/Thrift/Protocol/Binary.hs +++ b/lib/hs/src/Thrift/Protocol/Binary.hs @@ -25,6 +25,8 @@ module Thrift.Protocol.Binary ( module Thrift.Protocol , BinaryProtocol(..) + , versionMask + , version1 ) where import Control.Exception ( throw ) @@ -35,6 +37,7 @@ import Data.Functor import Data.Int import Data.Monoid import Data.Text.Lazy.Encoding ( decodeUtf8, encodeUtf8 ) +import Data.Word import Thrift.Protocol import Thrift.Transport @@ -47,37 +50,55 @@ import qualified Data.ByteString.Lazy as LBS import qualified Data.HashMap.Strict as Map import qualified Data.Text.Lazy as LT -data BinaryProtocol a = BinaryProtocol a +versionMask :: Int32 +versionMask = fromIntegral (0xffff0000 :: Word32) + +version1 :: Int32 +version1 = fromIntegral (0x80010000 :: Word32) + +data BinaryProtocol a = Transport a => BinaryProtocol a + +getTransport :: Transport t => BinaryProtocol t -> t +getTransport (BinaryProtocol t) = t -- NOTE: Reading and Writing functions rely on Builders and Data.Binary to -- encode and decode data. Data.Binary assumes that the binary values it is -- encoding to and decoding from are in BIG ENDIAN format, and converts the -- endianness as necessary to match the local machine. -instance Protocol BinaryProtocol where - getTransport (BinaryProtocol t) = t - - writeMessageBegin p (n, t, s) = tWrite (getTransport p) $ toLazyByteString $ - buildBinaryValue (TI32 (version1 .|. fromIntegral (fromEnum t))) <> - buildBinaryValue (TString $ encodeUtf8 n) <> - buildBinaryValue (TI32 s) - - readMessageBegin p = runParser p $ do - TI32 ver <- parseBinaryValue T_I32 - if ver .&. versionMask /= version1 - then throw $ ProtocolExn PE_BAD_VERSION "Missing version identifier" - else do - TString s <- parseBinaryValue T_STRING - TI32 sz <- parseBinaryValue T_I32 - return (decodeUtf8 s, toEnum $ fromIntegral $ ver .&. 0xFF, sz) +instance Transport t => Protocol (BinaryProtocol t) where + readByte p = tReadAll (getTransport p) 1 + -- flushTransport p = tFlush (getTransport p) + writeMessage p (n, t, s) f = do + tWrite (getTransport p) messageBegin + f + tFlush $ getTransport p + where + messageBegin = toLazyByteString $ + buildBinaryValue (TI32 (version1 .|. fromIntegral (fromEnum t))) <> + buildBinaryValue (TString $ encodeUtf8 n) <> + buildBinaryValue (TI32 s) + + readMessage p = (readMessageBegin p >>=) + where + readMessageBegin p = runParser p $ do + TI32 ver <- parseBinaryValue T_I32 + if ver .&. versionMask /= version1 + then throw $ ProtocolExn PE_BAD_VERSION "Missing version identifier" + else do + TString s <- parseBinaryValue T_STRING + TI32 sz <- parseBinaryValue T_I32 + return (decodeUtf8 s, toEnum $ fromIntegral $ ver .&. 0xFF, sz) + + writeVal p = tWrite (getTransport p) . toLazyByteString . buildBinaryValue + readVal p = runParser p . parseBinaryValue +instance Transport t => StatelessProtocol (BinaryProtocol t) where serializeVal _ = toLazyByteString . buildBinaryValue deserializeVal _ ty bs = case LP.eitherResult $ LP.parse (parseBinaryValue ty) bs of Left s -> error s Right val -> val - readVal p = runParser p . parseBinaryValue - -- | Writing Functions buildBinaryValue :: ThriftVal -> Builder buildBinaryValue (TStruct fields) = buildBinaryStruct fields <> buildType T_STOP diff --git a/lib/hs/src/Thrift/Protocol/Compact.hs b/lib/hs/src/Thrift/Protocol/Compact.hs index 07113df21..f23970a82 100644 --- a/lib/hs/src/Thrift/Protocol/Compact.hs +++ b/lib/hs/src/Thrift/Protocol/Compact.hs @@ -25,10 +25,11 @@ module Thrift.Protocol.Compact ( module Thrift.Protocol , CompactProtocol(..) + , parseVarint + , buildVarint ) where import Control.Applicative -import Control.Exception ( throw ) import Control.Monad import Data.Attoparsec.ByteString as P import Data.Attoparsec.ByteString.Lazy as LP @@ -40,7 +41,7 @@ import Data.Monoid import Data.Word import Data.Text.Lazy.Encoding ( decodeUtf8, encodeUtf8 ) -import Thrift.Protocol hiding (versionMask) +import Thrift.Protocol import Thrift.Transport import Thrift.Types @@ -64,38 +65,47 @@ typeBits = 0x07 -- 0000 0111 typeShiftAmount :: Int typeShiftAmount = 5 +getTransport :: Transport t => CompactProtocol t -> t +getTransport (CompactProtocol t) = t -instance Protocol CompactProtocol where - getTransport (CompactProtocol t) = t +instance Transport t => Protocol (CompactProtocol t) where + readByte p = tReadAll (getTransport p) 1 + writeMessage p (n, t, s) f = do + tWrite (getTransport p) messageBegin + f + tFlush $ getTransport p + where + messageBegin = toLazyByteString $ + B.word8 protocolID <> + B.word8 ((version .&. versionMask) .|. + (((fromIntegral $ fromEnum t) `shiftL` + typeShiftAmount) .&. typeMask)) <> + buildVarint (i32ToZigZag s) <> + buildCompactValue (TString $ encodeUtf8 n) - writeMessageBegin p (n, t, s) = tWrite (getTransport p) $ toLazyByteString $ - B.word8 protocolID <> - B.word8 ((version .&. versionMask) .|. - (((fromIntegral $ fromEnum t) `shiftL` - typeShiftAmount) .&. typeMask)) <> - buildVarint (i32ToZigZag s) <> - buildCompactValue (TString $ encodeUtf8 n) - - readMessageBegin p = runParser p $ do - pid <- fromIntegral <$> P.anyWord8 - when (pid /= protocolID) $ error "Bad Protocol ID" - w <- fromIntegral <$> P.anyWord8 - let ver = w .&. versionMask - when (ver /= version) $ error "Bad Protocol version" - let typ = (w `shiftR` typeShiftAmount) .&. typeBits - seqId <- parseVarint zigZagToI32 - TString name <- parseCompactValue T_STRING - return (decodeUtf8 name, toEnum $ fromIntegral $ typ, seqId) + readMessage p f = readMessageBegin >>= f + where + readMessageBegin = runParser p $ do + pid <- fromIntegral <$> P.anyWord8 + when (pid /= protocolID) $ error "Bad Protocol ID" + w <- fromIntegral <$> P.anyWord8 + let ver = w .&. versionMask + when (ver /= version) $ error "Bad Protocol version" + let typ = (w `shiftR` typeShiftAmount) .&. typeBits + seqId <- parseVarint zigZagToI32 + TString name <- parseCompactValue T_STRING + return (decodeUtf8 name, toEnum $ fromIntegral $ typ, seqId) + writeVal p = tWrite (getTransport p) . toLazyByteString . buildCompactValue + readVal p ty = runParser p $ parseCompactValue ty + +instance Transport t => StatelessProtocol (CompactProtocol t) where serializeVal _ = toLazyByteString . buildCompactValue deserializeVal _ ty bs = case LP.eitherResult $ LP.parse (parseCompactValue ty) bs of Left s -> error s Right val -> val - readVal p ty = runParser p $ parseCompactValue ty - - -- | Writing Functions buildCompactValue :: ThriftVal -> Builder buildCompactValue (TStruct fields) = buildCompactStruct fields @@ -283,7 +293,7 @@ typeOf v = case v of TSet{} -> 0x0A TMap{} -> 0x0B TStruct{} -> 0x0C - + typeFrom :: Word8 -> ThriftType typeFrom w = case w of 0x01 -> T_BOOL diff --git a/lib/hs/src/Thrift/Protocol/Header.hs b/lib/hs/src/Thrift/Protocol/Header.hs new file mode 100644 index 000000000..5f42db45d --- /dev/null +++ b/lib/hs/src/Thrift/Protocol/Header.hs @@ -0,0 +1,141 @@ +-- +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. +-- + + +module Thrift.Protocol.Header + ( module Thrift.Protocol + , HeaderProtocol(..) + , getProtocolType + , setProtocolType + , getHeaders + , getWriteHeaders + , setHeader + , setHeaders + , createHeaderProtocol + , createHeaderProtocol1 + ) where + +import Thrift.Protocol +import Thrift.Protocol.Binary +import Thrift.Protocol.JSON +import Thrift.Protocol.Compact +import Thrift.Transport +import Thrift.Transport.Header +import Data.IORef +import qualified Data.Map as Map + +data ProtocolWrap = forall a. (Protocol a) => ProtocolWrap(a) + +instance Protocol ProtocolWrap where + readByte (ProtocolWrap p) = readByte p + readVal (ProtocolWrap p) = readVal p + readMessage (ProtocolWrap p) = readMessage p + writeVal (ProtocolWrap p) = writeVal p + writeMessage (ProtocolWrap p) = writeMessage p + +data HeaderProtocol i o = (Transport i, Transport o) => HeaderProtocol { + trans :: HeaderTransport i o, + wrappedProto :: IORef ProtocolWrap + } + +createProtocolWrap :: Transport t => ProtocolType -> t -> ProtocolWrap +createProtocolWrap typ t = + case typ of + TBinary -> ProtocolWrap $ BinaryProtocol t + TCompact -> ProtocolWrap $ CompactProtocol t + TJSON -> ProtocolWrap $ JSONProtocol t + +createHeaderProtocol :: (Transport i, Transport o) => i -> o -> IO(HeaderProtocol i o) +createHeaderProtocol i o = do + t <- openHeaderTransport i o + pid <- readIORef $ protocolType t + proto <- newIORef $ createProtocolWrap pid t + return $ HeaderProtocol { trans = t, wrappedProto = proto } + +createHeaderProtocol1 :: Transport t => t -> IO(HeaderProtocol t t) +createHeaderProtocol1 t = createHeaderProtocol t t + +resetProtocol :: (Transport i, Transport o) => HeaderProtocol i o -> IO () +resetProtocol p = do + pid <- readIORef $ protocolType $ trans p + writeIORef (wrappedProto p) $ createProtocolWrap pid $ trans p + +getWrapped = readIORef . wrappedProto + +setTransport :: (Transport i, Transport o) => HeaderProtocol i o -> HeaderTransport i o -> HeaderProtocol i o +setTransport p t = p { trans = t } + +updateTransport :: (Transport i, Transport o) => HeaderProtocol i o -> (HeaderTransport i o -> HeaderTransport i o)-> HeaderProtocol i o +updateTransport p f = setTransport p (f $ trans p) + +type Headers = Map.Map String String + +-- TODO: we want to set headers without recreating client... +setHeader :: (Transport i, Transport o) => HeaderProtocol i o -> String -> String -> HeaderProtocol i o +setHeader p k v = updateTransport p $ \t -> t { writeHeaders = Map.insert k v $ writeHeaders t } + +setHeaders :: (Transport i, Transport o) => HeaderProtocol i o -> Headers -> HeaderProtocol i o +setHeaders p h = updateTransport p $ \t -> t { writeHeaders = h } + +-- TODO: make it public once we have first transform implementation for Haskell +setTransforms :: (Transport i, Transport o) => HeaderProtocol i o -> [TransformType] -> HeaderProtocol i o +setTransforms p trs = updateTransport p $ \t -> t { writeTransforms = trs } + +setTransform :: (Transport i, Transport o) => HeaderProtocol i o -> TransformType -> HeaderProtocol i o +setTransform p tr = updateTransport p $ \t -> t { writeTransforms = tr:(writeTransforms t) } + +getWriteHeaders :: (Transport i, Transport o) => HeaderProtocol i o -> Headers +getWriteHeaders = writeHeaders . trans + +getHeaders :: (Transport i, Transport o) => HeaderProtocol i o -> IO [(String, String)] +getHeaders = readIORef . headers . trans + +getProtocolType :: (Transport i, Transport o) => HeaderProtocol i o -> IO ProtocolType +getProtocolType p = readIORef $ protocolType $ trans p + +setProtocolType :: (Transport i, Transport o) => HeaderProtocol i o -> ProtocolType -> IO () +setProtocolType p typ = do + typ0 <- getProtocolType p + if typ == typ0 + then return () + else do + tSetProtocol (trans p) typ + resetProtocol p + +instance (Transport i, Transport o) => Protocol (HeaderProtocol i o) where + readByte p = tReadAll (trans p) 1 + + readVal p tp = do + proto <- getWrapped p + readVal proto tp + + readMessage p f = do + tResetProtocol (trans p) + resetProtocol p + proto <- getWrapped p + readMessage proto f + + writeVal p v = do + proto <- getWrapped p + writeVal proto v + + writeMessage p x f = do + proto <- getWrapped p + writeMessage proto x f + diff --git a/lib/hs/src/Thrift/Protocol/JSON.hs b/lib/hs/src/Thrift/Protocol/JSON.hs index 7f619e8cb..839eddc84 100644 --- a/lib/hs/src/Thrift/Protocol/JSON.hs +++ b/lib/hs/src/Thrift/Protocol/JSON.hs @@ -29,12 +29,12 @@ module Thrift.Protocol.JSON ) where import Control.Applicative +import Control.Exception (bracket) import Control.Monad import Data.Attoparsec.ByteString as P import Data.Attoparsec.ByteString.Char8 as PC import Data.Attoparsec.ByteString.Lazy as LP import Data.ByteString.Base64.Lazy as B64C -import Data.ByteString.Base64 as B64 import Data.ByteString.Lazy.Builder as B import Data.ByteString.Internal (c2w, w2c) import Data.Functor @@ -58,38 +58,48 @@ import qualified Data.Text.Lazy as LT -- encoded as a JSON 'ByteString' data JSONProtocol t = JSONProtocol t -- ^ Construct a 'JSONProtocol' with a 'Transport' +getTransport :: Transport t => JSONProtocol t -> t +getTransport (JSONProtocol t) = t -instance Protocol JSONProtocol where - getTransport (JSONProtocol t) = t +instance Transport t => Protocol (JSONProtocol t) where + readByte p = tReadAll (getTransport p) 1 - writeMessageBegin (JSONProtocol t) (s, ty, sq) = tWrite t $ toLazyByteString $ - B.char8 '[' <> buildShowable (1 :: Int32) <> - B.string8 ",\"" <> escape (encodeUtf8 s) <> B.char8 '\"' <> - B.char8 ',' <> buildShowable (fromEnum ty) <> - B.char8 ',' <> buildShowable sq <> - B.char8 ',' - writeMessageEnd (JSONProtocol t) = tWrite t "]" - readMessageBegin p = runParser p $ skipSpace *> do - _ver :: Int32 <- lexeme (PC.char8 '[') *> lexeme (signed decimal) - bs <- lexeme (PC.char8 ',') *> lexeme escapedString - case decodeUtf8' bs of - Left _ -> fail "readMessage: invalid text encoding" - Right str -> do - ty <- toEnum <$> (lexeme (PC.char8 ',') *> lexeme (signed decimal)) - seqNum <- lexeme (PC.char8 ',') *> lexeme (signed decimal) - _ <- PC.char8 ',' - return (str, ty, seqNum) - readMessageEnd p = void $ runParser p (PC.char8 ']') + writeMessage (JSONProtocol t) (s, ty, sq) = bracket readMessageBegin readMessageEnd . const + where + readMessageBegin = tWrite t $ toLazyByteString $ + B.char8 '[' <> buildShowable (1 :: Int32) <> + B.string8 ",\"" <> escape (encodeUtf8 s) <> B.char8 '\"' <> + B.char8 ',' <> buildShowable (fromEnum ty) <> + B.char8 ',' <> buildShowable sq <> + B.char8 ',' + readMessageEnd _ = do + tWrite t "]" + tFlush t + readMessage p = bracket readMessageBegin readMessageEnd + where + readMessageBegin = runParser p $ skipSpace *> do + _ver :: Int32 <- lexeme (PC.char8 '[') *> lexeme (signed decimal) + bs <- lexeme (PC.char8 ',') *> lexeme escapedString + case decodeUtf8' bs of + Left _ -> fail "readMessage: invalid text encoding" + Right str -> do + ty <- toEnum <$> (lexeme (PC.char8 ',') *> lexeme (signed decimal)) + seqNum <- lexeme (PC.char8 ',') *> lexeme (signed decimal) + _ <- PC.char8 ',' + return (str, ty, seqNum) + readMessageEnd _ = void $ runParser p (PC.char8 ']') + + writeVal p = tWrite (getTransport p) . toLazyByteString . buildJSONValue + readVal p ty = runParser p $ skipSpace *> parseJSONValue ty + +instance Transport t => StatelessProtocol (JSONProtocol t) where serializeVal _ = toLazyByteString . buildJSONValue deserializeVal _ ty bs = case LP.eitherResult $ LP.parse (parseJSONValue ty) bs of Left s -> error s Right val -> val - readVal p ty = runParser p $ skipSpace *> parseJSONValue ty - - -- Writing Functions buildJSONValue :: ThriftVal -> Builder diff --git a/lib/hs/src/Thrift/Server.hs b/lib/hs/src/Thrift/Server.hs index ed74ceba6..543f33850 100644 --- a/lib/hs/src/Thrift/Server.hs +++ b/lib/hs/src/Thrift/Server.hs @@ -38,10 +38,10 @@ import Thrift.Protocol.Binary -- | A threaded sever that is capable of using any Transport or Protocol -- instances. -runThreadedServer :: (Transport t, Protocol i, Protocol o) - => (Socket -> IO (i t, o t)) +runThreadedServer :: (Protocol i, Protocol o) + => (Socket -> IO (i, o)) -> h - -> (h -> (i t, o t) -> IO Bool) + -> (h -> (i, o) -> IO Bool) -> PortID -> IO a runThreadedServer accepter hand proc_ port = do diff --git a/lib/hs/src/Thrift/Transport/Handle.hs b/lib/hs/src/Thrift/Transport/Handle.hs index b7d16e4fb..ff6295b67 100644 --- a/lib/hs/src/Thrift/Transport/Handle.hs +++ b/lib/hs/src/Thrift/Transport/Handle.hs @@ -44,7 +44,13 @@ import Data.Monoid instance Transport Handle where tIsOpen = hIsOpen tClose = hClose - tRead h n = LBS.hGet h n `Control.Exception.catch` handleEOF mempty + tRead h n = read `Control.Exception.catch` handleEOF mempty + where + read = do + hLookAhead h + LBS.hGetNonBlocking h n + tReadAll _ 0 = return mempty + tReadAll h n = LBS.hGet h n `Control.Exception.catch` throwTransportExn tPeek h = (Just . c2w <$> hLookAhead h) `Control.Exception.catch` handleEOF Nothing tWrite = LBS.hPut tFlush = hFlush @@ -61,8 +67,12 @@ instance HandleSource FilePath where instance HandleSource (HostName, PortID) where hOpen = uncurry connectTo +throwTransportExn :: IOError -> IO a +throwTransportExn e = if isEOFError e + then throw $ TransportExn "Cannot read. Remote side has closed." TE_UNKNOWN + else throw $ TransportExn "Handle tReadAll: Could not read" TE_UNKNOWN handleEOF :: a -> IOError -> IO a handleEOF a e = if isEOFError e then return a - else throw $ TransportExn "TChannelTransport: Could not read" TE_UNKNOWN + else throw $ TransportExn "Handle: Could not read" TE_UNKNOWN diff --git a/lib/hs/src/Thrift/Transport/Header.hs b/lib/hs/src/Thrift/Transport/Header.hs new file mode 100644 index 000000000..2dacad25f --- /dev/null +++ b/lib/hs/src/Thrift/Transport/Header.hs @@ -0,0 +1,354 @@ +-- +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. +-- + +module Thrift.Transport.Header + ( module Thrift.Transport + , HeaderTransport(..) + , openHeaderTransport + , ProtocolType(..) + , TransformType(..) + , ClientType(..) + , tResetProtocol + , tSetProtocol + ) where + +import Thrift.Transport +import Thrift.Protocol.Compact +import Control.Applicative +import Control.Exception ( throw ) +import Control.Monad +import Data.Bits +import Data.IORef +import Data.Int +import Data.Monoid +import Data.Word + +import qualified Data.Attoparsec.ByteString as P +import qualified Data.Binary as Binary +import qualified Data.ByteString as BS +import qualified Data.ByteString.Char8 as C +import qualified Data.ByteString.Lazy as LBS +import qualified Data.ByteString.Lazy.Builder as B +import qualified Data.Map as Map + +data ProtocolType = TBinary | TCompact | TJSON deriving (Enum, Eq) +data ClientType = HeaderClient | Framed | Unframed deriving (Enum, Eq) + +infoIdKeyValue = 1 + +type Headers = Map.Map String String + +data TransformType = ZlibTransform deriving (Enum, Eq) + +fromTransportType :: TransformType -> Int16 +fromTransportType ZlibTransform = 1 + +toTransportType :: Int16 -> TransformType +toTransportType 1 = ZlibTransform +toTransportType _ = throw $ TransportExn "HeaderTransport: Unknown transform ID" TE_UNKNOWN + +data HeaderTransport i o = (Transport i, Transport o) => HeaderTransport + { readBuffer :: IORef LBS.ByteString + , writeBuffer :: IORef B.Builder + , inTrans :: i + , outTrans :: o + , clientType :: IORef ClientType + , protocolType :: IORef ProtocolType + , headers :: IORef [(String, String)] + , writeHeaders :: Headers + , transforms :: IORef [TransformType] + , writeTransforms :: [TransformType] + } + +openHeaderTransport :: (Transport i, Transport o) => i -> o -> IO (HeaderTransport i o) +openHeaderTransport i o = do + pid <- newIORef TCompact + rBuf <- newIORef LBS.empty + wBuf <- newIORef mempty + cType <- newIORef HeaderClient + h <- newIORef [] + trans <- newIORef [] + return HeaderTransport + { readBuffer = rBuf + , writeBuffer = wBuf + , inTrans = i + , outTrans = o + , clientType = cType + , protocolType = pid + , headers = h + , writeHeaders = Map.empty + , transforms = trans + , writeTransforms = [] + } + +isFramed t = (/= Unframed) <$> readIORef (clientType t) + +readFrame :: (Transport i, Transport o) => HeaderTransport i o -> IO Bool +readFrame t = do + let input = inTrans t + let rBuf = readBuffer t + let cType = clientType t + lsz <- tRead input 4 + let sz = LBS.toStrict lsz + case P.parseOnly P.endOfInput sz of + Right _ -> do return False + Left _ -> do + case parseBinaryMagic sz of + Right _ -> do + writeIORef rBuf $ lsz + writeIORef cType Unframed + writeIORef (protocolType t) TBinary + return True + Left _ -> do + case parseCompactMagic sz of + Right _ -> do + writeIORef rBuf $ lsz + writeIORef cType Unframed + writeIORef (protocolType t) TCompact + return True + Left _ -> do + let len = Binary.decode lsz :: Int32 + lbuf <- tReadAll input $ fromIntegral len + let buf = LBS.toStrict lbuf + case parseBinaryMagic buf of + Right _ -> do + writeIORef cType Framed + writeIORef (protocolType t) TBinary + writeIORef rBuf lbuf + return True + Left _ -> do + case parseCompactMagic buf of + Right _ -> do + writeIORef cType Framed + writeIORef (protocolType t) TCompact + writeIORef rBuf lbuf + return True + Left _ -> do + case parseHeaderMagic buf of + Right flags -> do + let (flags, seqNum, header, body) = extractHeader buf + writeIORef cType HeaderClient + handleHeader t header + payload <- untransform t body + writeIORef rBuf $ LBS.fromStrict $ payload + return True + Left _ -> + throw $ TransportExn "HeaderTransport: unkonwn client type" TE_UNKNOWN + +parseBinaryMagic = P.parseOnly $ P.word8 0x80 *> P.word8 0x01 *> P.word8 0x00 *> P.anyWord8 +parseCompactMagic = P.parseOnly $ P.word8 0x82 *> P.satisfy (\b -> b .&. 0x1f == 0x01) +parseHeaderMagic = P.parseOnly $ P.word8 0x0f *> P.word8 0xff *> (P.count 2 P.anyWord8) + +parseI32 :: P.Parser Int32 +parseI32 = Binary.decode . LBS.fromStrict <$> P.take 4 +parseI16 :: P.Parser Int16 +parseI16 = Binary.decode . LBS.fromStrict <$> P.take 2 + +extractHeader :: BS.ByteString -> (Int16, Int32, BS.ByteString, BS.ByteString) +extractHeader bs = + case P.parse extractHeader_ bs of + P.Done remain (flags, seqNum, header) -> (flags, seqNum, header, remain) + _ -> throw $ TransportExn "HeaderTransport: Invalid header" TE_UNKNOWN + where + extractHeader_ = do + magic <- P.word8 0x0f *> P.word8 0xff + flags <- parseI16 + seqNum <- parseI32 + (headerSize :: Int) <- (* 4) . fromIntegral <$> parseI16 + header <- P.take headerSize + return (flags, seqNum, header) + +handleHeader t header = + case P.parseOnly parseHeader header of + Right (pType, trans, info) -> do + writeIORef (protocolType t) pType + writeIORef (transforms t) trans + writeIORef (headers t) info + _ -> throw $ TransportExn "HeaderTransport: Invalid header" TE_UNKNOWN + + +iw16 :: Int16 -> Word16 +iw16 = fromIntegral +iw32 :: Int32 -> Word32 +iw32 = fromIntegral +wi16 :: Word16 -> Int16 +wi16 = fromIntegral +wi32 :: Word32 -> Int32 +wi32 = fromIntegral + +parseHeader :: P.Parser (ProtocolType, [TransformType], [(String, String)]) +parseHeader = do + protocolType <- toProtocolType <$> parseVarint wi16 + numTrans <- fromIntegral <$> parseVarint wi16 + trans <- replicateM numTrans parseTransform + info <- parseInfo + return (protocolType, trans, info) + +toProtocolType :: Int16 -> ProtocolType +toProtocolType 0 = TBinary +toProtocolType 1 = TJSON +toProtocolType 2 = TCompact + +fromProtocolType :: ProtocolType -> Int16 +fromProtocolType TBinary = 0 +fromProtocolType TJSON = 1 +fromProtocolType TCompact = 2 + +parseTransform :: P.Parser TransformType +parseTransform = toTransportType <$> parseVarint wi16 + +parseInfo :: P.Parser [(String, String)] +parseInfo = do + n <- P.eitherP P.endOfInput (parseVarint wi32) + case n of + Left _ -> return [] + Right n0 -> + replicateM (fromIntegral n0) $ do + klen <- parseVarint wi16 + k <- P.take $ fromIntegral klen + vlen <- parseVarint wi16 + v <- P.take $ fromIntegral vlen + return (C.unpack k, C.unpack v) + +parseString :: P.Parser BS.ByteString +parseString = parseVarint wi32 >>= (P.take . fromIntegral) + +buildHeader :: HeaderTransport i o -> IO B.Builder +buildHeader t = do + pType <- readIORef $ protocolType t + let pId = buildVarint $ iw16 $ fromProtocolType pType + let headerContent = pId <> (buildTransforms t) <> (buildInfo t) + let len = fromIntegral $ LBS.length $ B.toLazyByteString headerContent + -- TODO: length limit check + let padding = mconcat $ replicate (mod len 4) $ B.word8 0 + let codedLen = B.int16BE (fromIntegral $ (quot (len - 1) 4) + 1) + let flags = 0 + let seqNum = 0 + return $ B.int16BE 0x0fff <> B.int16BE flags <> B.int32BE seqNum <> codedLen <> headerContent <> padding + +buildTransforms :: HeaderTransport i o -> B.Builder +-- TODO: check length limit +buildTransforms t = + let trans = writeTransforms t in + (buildVarint $ iw16 $ fromIntegral $ length trans) <> + (mconcat $ map (buildVarint . iw16 . fromTransportType) trans) + +buildInfo :: HeaderTransport i o -> B.Builder +buildInfo t = + let h = Map.assocs $ writeHeaders t in + -- TODO: check length limit + case length h of + 0 -> mempty + len -> (buildVarint $ iw16 $ fromIntegral $ len) <> (mconcat $ map buildInfoEntry h) + where + buildInfoEntry (k, v) = buildVarStr k <> buildVarStr v + -- TODO: check length limit + buildVarStr s = (buildVarint $ iw16 $ fromIntegral $ length s) <> B.string8 s + +tResetProtocol :: (Transport i, Transport o) => HeaderTransport i o -> IO Bool +tResetProtocol t = do + rBuf <- readIORef $ readBuffer t + writeIORef (clientType t) HeaderClient + readFrame t + +tSetProtocol :: (Transport i, Transport o) => HeaderTransport i o -> ProtocolType -> IO () +tSetProtocol t = writeIORef (protocolType t) + +transform :: HeaderTransport i o -> LBS.ByteString -> LBS.ByteString +transform t bs = + foldr applyTransform bs $ writeTransforms t + where + -- applyTransform bs ZlibTransform = + -- throw $ TransportExn "HeaderTransport: not implemented: ZlibTransform " TE_UNKNOWN + applyTransform bs _ = + throw $ TransportExn "HeaderTransport: Unknown transform" TE_UNKNOWN + +untransform :: HeaderTransport i o -> BS.ByteString -> IO BS.ByteString +untransform t bs = do + trans <- readIORef $ transforms t + return $ foldl unapplyTransform bs trans + where + -- unapplyTransform bs ZlibTransform = + -- throw $ TransportExn "HeaderTransport: not implemented: ZlibTransform " TE_UNKNOWN + unapplyTransform bs _ = + throw $ TransportExn "HeaderTransport: Unknown transform" TE_UNKNOWN + +instance (Transport i, Transport o) => Transport (HeaderTransport i o) where + tIsOpen t = do + tIsOpen (inTrans t) + tIsOpen (outTrans t) + + tClose t = do + tClose(outTrans t) + tClose(inTrans t) + + tRead t len = do + rBuf <- readIORef $ readBuffer t + if not $ LBS.null rBuf + then do + let (consumed, remain) = LBS.splitAt (fromIntegral len) rBuf + writeIORef (readBuffer t) remain + return consumed + else do + framed <- isFramed t + if not framed + then tRead (inTrans t) len + else do + ok <- readFrame t + if ok + then tRead t len + else return LBS.empty + + tPeek t = do + rBuf <- readIORef (readBuffer t) + if not $ LBS.null rBuf + then return $ Just $ LBS.head rBuf + else do + framed <- isFramed t + if not framed + then tPeek (inTrans t) + else do + ok <- readFrame t + if ok + then tPeek t + else return Nothing + + tWrite t buf = do + let wBuf = writeBuffer t + framed <- isFramed t + if framed + then modifyIORef wBuf (<> B.lazyByteString buf) + else + -- TODO: what should we do when switched to unframed in the middle ? + tWrite(outTrans t) buf + + tFlush t = do + cType <- readIORef $ clientType t + case cType of + Unframed -> tFlush $ outTrans t + Framed -> flushBuffer t id mempty + HeaderClient -> buildHeader t >>= flushBuffer t (transform t) + where + flushBuffer t f header = do + wBuf <- readIORef $ writeBuffer t + writeIORef (writeBuffer t) mempty + let payload = B.toLazyByteString (header <> wBuf) + tWrite (outTrans t) $ Binary.encode (fromIntegral $ LBS.length payload :: Int32) + tWrite (outTrans t) $ f payload + tFlush (outTrans t) diff --git a/lib/hs/thrift.cabal b/lib/hs/thrift.cabal index 9754ab2ee..583067953 100644 --- a/lib/hs/thrift.cabal +++ b/lib/hs/thrift.cabal @@ -49,6 +49,7 @@ Library Thrift, Thrift.Arbitraries Thrift.Protocol, + Thrift.Protocol.Header, Thrift.Protocol.Binary, Thrift.Protocol.Compact, Thrift.Protocol.JSON, @@ -57,6 +58,7 @@ Library Thrift.Transport.Empty, Thrift.Transport.Framed, Thrift.Transport.Handle, + Thrift.Transport.Header, Thrift.Transport.HttpClient, Thrift.Transport.IOBuffer, Thrift.Transport.Memory, diff --git a/test/hs/TestClient.hs b/test/hs/TestClient.hs index d1ebb3cd0..93fb591b3 100644 --- a/test/hs/TestClient.hs +++ b/test/hs/TestClient.hs @@ -46,6 +46,7 @@ import Thrift.Transport.HttpClient import Thrift.Protocol import Thrift.Protocol.Binary import Thrift.Protocol.Compact +import Thrift.Protocol.Header import Thrift.Protocol.JSON data Options = Options @@ -85,12 +86,14 @@ getTransport t host port = do return (NoTransport $ "Unsupported transport: " ++ data ProtocolType = Binary | Compact | JSON + | Header deriving (Show, Eq) getProtocol :: String -> ProtocolType getProtocol "binary" = Binary getProtocol "compact" = Compact getProtocol "json" = JSON +getProtocol "header" = Header getProtocol p = error $ "Unsupported Protocol: " ++ p defaultOptions :: Options @@ -104,7 +107,7 @@ defaultOptions = Options , testLoops = 1 } -runClient :: (Protocol p, Transport t) => p t -> IO () +runClient :: Protocol p => p -> IO () runClient p = do let prot = (p,p) putStrLn "Starting Tests" @@ -266,6 +269,7 @@ main = do Binary -> runClient $ BinaryProtocol t Compact -> runClient $ CompactProtocol t JSON -> runClient $ JSONProtocol t + Header -> createHeaderProtocol t t >>= runClient runTest loops p t = do let client = makeClient p t replicateM_ loops client diff --git a/test/hs/TestServer.hs b/test/hs/TestServer.hs index 4a88649b8..b7731ab1c 100644 --- a/test/hs/TestServer.hs +++ b/test/hs/TestServer.hs @@ -48,6 +48,7 @@ import Thrift.Transport.Framed import Thrift.Transport.Handle import Thrift.Protocol.Binary import Thrift.Protocol.Compact +import Thrift.Protocol.Header import Thrift.Protocol.JSON data Options = Options @@ -90,11 +91,13 @@ getTransport t = NoTransport $ "Unsupported transport: " ++ t data ProtocolType = Binary | Compact | JSON + | Header getProtocol :: String -> ProtocolType getProtocol "binary" = Binary getProtocol "compact" = Compact getProtocol "json" = JSON +getProtocol "header" = Header getProtocol p = error $"Unsupported Protocol: " ++ p defaultOptions :: Options @@ -261,13 +264,19 @@ main = do t <- f socket return (p t, p t) + headerAcceptor f socket = do + t <- f socket + p <- createHeaderProtocol1 t + return (p, p) + doRunServer p f = do runThreadedServer (acceptor p f) TestHandler ThriftTest.process . PortNumber . fromIntegral runServer p f port = case p of - Binary -> do doRunServer BinaryProtocol f port - Compact -> do doRunServer CompactProtocol f port - JSON -> do doRunServer JSONProtocol f port + Binary -> doRunServer BinaryProtocol f port + Compact -> doRunServer CompactProtocol f port + JSON -> doRunServer JSONProtocol f port + Header -> runThreadedServer (headerAcceptor f) TestHandler ThriftTest.process (PortNumber $ fromIntegral port) parseFlags :: [String] -> Options -> Maybe Options parseFlags (flag : flags) opts = do diff --git a/test/known_failures_Linux.json b/test/known_failures_Linux.json index c96198808..754535f12 100644 --- a/test/known_failures_Linux.json +++ b/test/known_failures_Linux.json @@ -229,6 +229,8 @@ "go-java_json_http-ip", "go-java_json_http-ip-ssl", "go-nodejs_json_framed-ip", + "hs-csharp_binary_framed-ip", + "hs-csharp_compact_framed-ip", "hs-dart_binary_framed-ip", "hs-dart_compact_framed-ip", "hs-dart_json_framed-ip", @@ -331,4 +333,4 @@ "rs-dart_compact_framed-ip", "rs-dart_multi-binary_framed-ip", "rs-dart_multic-compact_framed-ip" -]
\ No newline at end of file +] diff --git a/test/tests.json b/test/tests.json index 35d0a6cc1..c4e07eefb 100644 --- a/test/tests.json +++ b/test/tests.json @@ -216,6 +216,7 @@ "ip" ], "protocols": [ + "header", "compact", "binary", "json" |