summaryrefslogtreecommitdiff
path: root/lib/py/src/protocol/THeaderProtocol.py
blob: 4b58e639da265dac212915ec092cc1b9a6cbf4a8 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#

from thrift.protocol.TBinaryProtocol import TBinaryProtocolAccelerated
from thrift.protocol.TCompactProtocol import TCompactProtocolAccelerated
from thrift.protocol.TProtocol import TProtocolBase, TProtocolException, TProtocolFactory
from thrift.Thrift import TApplicationException, TMessageType
from thrift.transport.THeaderTransport import THeaderTransport, THeaderSubprotocolID, THeaderClientType


# Lookup table from the subprotocol id exposed by the transport
# (``self.trans.protocol_id``) to the protocol class THeaderProtocol
# instantiates to encode/decode the framed payload. An id that is not a key
# here makes THeaderProtocol._set_protocol raise TApplicationException.
PROTOCOLS_BY_ID = {
    THeaderSubprotocolID.BINARY: TBinaryProtocolAccelerated,
    THeaderSubprotocolID.COMPACT: TCompactProtocolAccelerated,
}


class THeaderProtocol(TProtocolBase):
    """A framed protocol with headers and payload transforms.

    THeaderProtocol frames other Thrift protocols and adds support for optional
    out-of-band headers. The currently supported subprotocols are
    TBinaryProtocol and TCompactProtocol. When used as a client, the
    subprotocol to frame can be chosen with the `default_protocol` parameter to
    the constructor.

    It's also possible to apply transforms to the encoded message payload. The
    only transform currently supported is to gzip.

    When used in a server, THeaderProtocol can accept messages from
    non-THeaderProtocol clients if allowed (see `allowed_client_types`). This
    includes framed and unframed transports and both TBinaryProtocol and
    TCompactProtocol. The server will respond in the appropriate dialect for
    the connected client. HTTP clients are not currently supported.

    THeaderProtocol does not currently support THTTPServer, TNonblockingServer,
    or TProcessPoolServer.

    See doc/specs/HeaderFormat.md for details of the wire format.

    """

    def __init__(self, transport, allowed_client_types, default_protocol=THeaderSubprotocolID.BINARY):
        # much of the actual work for THeaderProtocol happens down in
        # THeaderTransport since we need to do low-level shenanigans to detect
        # if the client is sending us headers or one of the headerless formats
        # we support. this wraps the real transport with the one that does all
        # the magic.
        if not isinstance(transport, THeaderTransport):
            transport = THeaderTransport(transport, allowed_client_types, default_protocol)
        super(THeaderProtocol, self).__init__(transport)
        self._set_protocol()

    # --- header / transform management: delegated to THeaderTransport ---

    def get_headers(self):
        """Return the headers held by the underlying THeaderTransport."""
        return self.trans.get_headers()

    def set_header(self, key, value):
        """Set header *key* to *value* on the underlying THeaderTransport."""
        self.trans.set_header(key, value)

    def clear_headers(self):
        """Clear the headers held by the underlying THeaderTransport."""
        self.trans.clear_headers()

    def add_transform(self, transform_id):
        """Register payload transform *transform_id* on the transport."""
        self.trans.add_transform(transform_id)

    # --- write side: every call is forwarded to the active subprotocol ---

    def writeMessageBegin(self, name, ttype, seqid):
        # The transport needs the sequence id for the frame it writes.
        self.trans.sequence_id = seqid
        return self._protocol.writeMessageBegin(name, ttype, seqid)

    def writeMessageEnd(self):
        return self._protocol.writeMessageEnd()

    def writeStructBegin(self, name):
        return self._protocol.writeStructBegin(name)

    def writeStructEnd(self):
        return self._protocol.writeStructEnd()

    def writeFieldBegin(self, name, ttype, fid):
        return self._protocol.writeFieldBegin(name, ttype, fid)

    def writeFieldEnd(self):
        return self._protocol.writeFieldEnd()

    def writeFieldStop(self):
        return self._protocol.writeFieldStop()

    def writeMapBegin(self, ktype, vtype, size):
        return self._protocol.writeMapBegin(ktype, vtype, size)

    def writeMapEnd(self):
        return self._protocol.writeMapEnd()

    def writeListBegin(self, etype, size):
        return self._protocol.writeListBegin(etype, size)

    def writeListEnd(self):
        return self._protocol.writeListEnd()

    def writeSetBegin(self, etype, size):
        return self._protocol.writeSetBegin(etype, size)

    def writeSetEnd(self):
        return self._protocol.writeSetEnd()

    def writeBool(self, bool_val):
        return self._protocol.writeBool(bool_val)

    def writeByte(self, byte):
        return self._protocol.writeByte(byte)

    def writeI16(self, i16):
        return self._protocol.writeI16(i16)

    def writeI32(self, i32):
        return self._protocol.writeI32(i32)

    def writeI64(self, i64):
        return self._protocol.writeI64(i64)

    def writeDouble(self, dub):
        return self._protocol.writeDouble(dub)

    def writeBinary(self, str_val):
        return self._protocol.writeBinary(str_val)

    def _set_protocol(self):
        """(Re)create the subprotocol selected by ``self.trans.protocol_id``.

        Also re-exports the subprotocol's ``_fast_encode``/``_fast_decode``
        hooks on this object.

        Raises:
            TApplicationException: of type INVALID_PROTOCOL when the
                transport's protocol id is not in PROTOCOLS_BY_ID.
        """
        try:
            protocol_cls = PROTOCOLS_BY_ID[self.trans.protocol_id]
        except KeyError:
            # Use TApplicationException's own type code; TProtocolException's
            # codes are a separate enumeration and would report the wrong
            # exception type to the peer.
            raise TApplicationException(
                TApplicationException.INVALID_PROTOCOL,
                "Unknown protocol requested.",
            )

        self._protocol = protocol_cls(self.trans)
        self._fast_encode = self._protocol._fast_encode
        self._fast_decode = self._protocol._fast_decode

    def readMessageBegin(self):
        """Read the next frame, switch subprotocol if needed, begin the message.

        If reading the frame or selecting its subprotocol raises
        TApplicationException, that exception is serialized back to the peer
        with the current subprotocol, flushed, and then re-raised. The frame
        cannot be decoded, so carrying on would only read garbage (this
        mirrors the C++ THeaderProtocol, which re-throws here too).

        Raises:
            TApplicationException: when the incoming frame is unusable.
        """
        try:
            self.trans.readFrame(0)
            self._set_protocol()
        except TApplicationException as exc:
            # Pass the name as a native str: the subprotocols convert it via
            # their string encoding, which expects text on Python 3 (a bytes
            # literal breaks there). On Python 2 "" is already bytes.
            self._protocol.writeMessageBegin("", TMessageType.EXCEPTION, 0)
            exc.write(self._protocol)
            self._protocol.writeMessageEnd()
            self.trans.flush()
            raise

        return self._protocol.readMessageBegin()

    # --- read side: every call is forwarded to the active subprotocol ---

    def readMessageEnd(self):
        return self._protocol.readMessageEnd()

    def readStructBegin(self):
        return self._protocol.readStructBegin()

    def readStructEnd(self):
        return self._protocol.readStructEnd()

    def readFieldBegin(self):
        return self._protocol.readFieldBegin()

    def readFieldEnd(self):
        return self._protocol.readFieldEnd()

    def readMapBegin(self):
        return self._protocol.readMapBegin()

    def readMapEnd(self):
        return self._protocol.readMapEnd()

    def readListBegin(self):
        return self._protocol.readListBegin()

    def readListEnd(self):
        return self._protocol.readListEnd()

    def readSetBegin(self):
        return self._protocol.readSetBegin()

    def readSetEnd(self):
        return self._protocol.readSetEnd()

    def readBool(self):
        return self._protocol.readBool()

    def readByte(self):
        return self._protocol.readByte()

    def readI16(self):
        return self._protocol.readI16()

    def readI32(self):
        return self._protocol.readI32()

    def readI64(self):
        return self._protocol.readI64()

    def readDouble(self):
        return self._protocol.readDouble()

    def readBinary(self):
        return self._protocol.readBinary()

class THeaderProtocolFactory(TProtocolFactory):
    """Build THeaderProtocol instances that share one configuration.

    Each protocol handed out by getProtocol() receives the allowed client
    types and default subprotocol supplied to this factory's constructor.
    """

    def __init__(
        self,
        allowed_client_types=(THeaderClientType.HEADERS,),
        default_protocol=THeaderSubprotocolID.BINARY,
    ):
        # Kept as public attributes; forwarded unchanged to every protocol.
        self.default_protocol = default_protocol
        self.allowed_client_types = allowed_client_types

    def getProtocol(self, trans):
        """Return a new THeaderProtocol wrapping *trans* with these settings."""
        return THeaderProtocol(
            trans,
            allowed_client_types=self.allowed_client_types,
            default_protocol=self.default_protocol,
        )