summaryrefslogtreecommitdiff
path: root/wsme/rest/args.py
blob: 4574cdde736c979c98f7706bc07377c33959484f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
import cgi
import datetime
import re

from simplegeneric import generic

from wsme.exc import ClientSideError, UnknownArgument, InvalidInput

from wsme.types import iscomplex, list_attributes, Unset
from wsme.types import UserType, ArrayType, DictType, File
from wsme.utils import parse_isodate, parse_isotime, parse_isodatetime
import wsme.runtime

from six import moves

# Upper bound intended for array sizes.
# NOTE(review): not referenced by any code visible in this module -- confirm
# where (or whether) this limit is actually enforced before relying on it.
ARRAY_MAX_SIZE = 1000


@generic
def from_param(datatype, value):
    """Convert one raw parameter value to ``datatype``.

    Generic entry point; type-specific variants are registered with
    ``from_param.when_object`` / ``from_param.when_type``.  The default
    implementation calls ``datatype(value)`` and maps any falsy ``value``
    to ``None``.
    """
    if not value:
        return None
    return datatype(value)


@from_param.when_object(datetime.date)
def date_from_param(datatype, value):
    """Parse ``value`` with ``parse_isodate``; falsy values become ``None``."""
    if not value:
        return None
    return parse_isodate(value)


@from_param.when_object(datetime.time)
def time_from_param(datatype, value):
    """Parse ``value`` with ``parse_isotime``; falsy values become ``None``."""
    if not value:
        return None
    return parse_isotime(value)


@from_param.when_object(datetime.datetime)
def datetime_from_param(datatype, value):
    """Parse ``value`` with ``parse_isodatetime``; falsy values become ``None``."""
    if not value:
        return None
    return parse_isodatetime(value)


@from_param.when_object(File)
def filetype_from_param(datatype, value):
    """Wrap ``value`` in a :class:`File`.

    A ``cgi.FieldStorage`` is handed over as ``fieldstorage``; any other
    value is treated as the raw file content.
    """
    if not isinstance(value, cgi.FieldStorage):
        return File(content=value)
    return File(fieldstorage=value)


@from_param.when_type(UserType)
def usertype_from_param(datatype, value):
    """Convert ``value`` to the user type's base type, then lift it.

    The base-type conversion goes through ``from_param`` again, and the
    result is passed to ``datatype.frombasetype``.
    """
    base_value = from_param(datatype.basetype, value)
    return datatype.frombasetype(base_value)


@from_param.when_type(ArrayType)
def array_from_param(datatype, value):
    """Convert every element of ``value`` to the array's item type.

    ``None`` is returned unchanged; otherwise a new list is built by
    running ``from_param`` on each element.
    """
    if value is None:
        return None
    return [from_param(datatype.item_type, element) for element in value]


@generic
def from_params(datatype, params, path, hit_paths):
    """Extract a value of ``datatype`` stored under ``path`` in ``params``.

    Complex types (other than ``File``) are rebuilt attribute by attribute
    from ``<path>.<attr>`` parameters, recursively; they are only
    instantiated when at least one such parameter exists.  Any other type
    is read from ``params[path]`` through ``from_param`` and ``path`` is
    recorded in ``hit_paths``.  Returns ``Unset`` when nothing matches.
    """
    if not iscomplex(datatype) or datatype is File:
        if path in params:
            hit_paths.add(path)
            return from_param(datatype, params[path])
        return Unset

    prefix = path + '.'
    if not any(key.startswith(prefix) for key in params):
        return Unset

    instance = datatype()
    for attrdef in list_attributes(datatype):
        attrvalue = from_params(
            attrdef.datatype,
            params, '%s.%s' % (path, attrdef.key), hit_paths
        )
        if attrvalue is not Unset:
            setattr(instance, attrdef.key, attrvalue)
    return instance


def _getall(params, path):
    """Return every value submitted for ``path`` as a list.

    Supports WebOb multidicts (``getall``) and Werkzeug multidicts
    (``getlist``).  Any other mapping is treated as a plain dict: a missing
    key yields ``[]``, a list/tuple value is returned as a list, and any
    other value is wrapped in a one-element list.
    """
    if hasattr(params, 'getall'):
        # webob multidict
        return params.getall(path)
    if hasattr(params, 'getlist'):
        # werkzeug multidict
        return params.getlist(path)
    # Plain-mapping fallback, so callers never end up without a getall.
    if path not in params:
        return []
    value = params[path]
    if isinstance(value, (list, tuple)):
        return list(value)
    return [value]


@from_params.when_type(ArrayType)
def array_from_params(datatype, params, path, hit_paths):
    """Extract an array stored under ``path`` in ``params``.

    Three encodings are recognised, tried in this order:

    * repeated values for ``path`` itself (``a=1&a=2``), each converted
      with ``from_param``;
    * for complex item types, parallel per-attribute lists
      (``a.x=1&a.x=2&a.y=3``): the n-th value of each attribute is set on
      the n-th item;
    * indexed entries (``a[0]=...&a[1]=...``), returned in index order.

    Parameter names consumed here are recorded in ``hit_paths``.  Returns
    ``Unset`` when none of the encodings is present.
    """
    if path in params:
        hit_paths.add(path)
        return [
            from_param(datatype.item_type, value)
            for value in _getall(params, path)]

    if iscomplex(datatype.item_type):
        attr_re = re.compile(r'^%s\.(?P<attrname>[^\.]+)' % re.escape(path))
        if any(attr_re.match(p) for p in params.keys()):
            value = []
            for attrdef in list_attributes(datatype.item_type):
                attrpath = '%s.%s' % (path, attrdef.key)
                hit_paths.add(attrpath)
                attrvalues = _getall(params, attrpath)
                # Append new items (never replace existing ones) so values
                # already set from earlier attributes are preserved and the
                # list is long enough for every submitted value.
                missing = len(attrvalues) - len(value)
                if missing > 0:
                    value.extend(
                        datatype.item_type()
                        for _ in moves.range(missing)
                    )
                for item, attrvalue in zip(value, attrvalues):
                    setattr(
                        item,
                        attrdef.key,
                        from_param(attrdef.datatype, attrvalue)
                    )
            return value

    index_re = re.compile(r'^%s\[(?P<index>\d+)\]' % re.escape(path))
    indexes = set()
    for p in params.keys():
        m = index_re.match(p)
        if m:
            indexes.add(int(m.group('index')))

    if not indexes:
        return Unset

    return [from_params(datatype.item_type, params,
                        '%s[%s]' % (path, index), hit_paths)
            for index in sorted(indexes)]


@from_params.when_type(DictType)
def dict_from_params(datatype, params, path, hit_paths):
    """Extract a dict stored as ``<path>[<key>]...`` parameters.

    Keys may contain letters, digits, ``_`` and ``.``.  Each key is
    converted with ``from_param(datatype.key_type, ...)``, and each value
    is extracted recursively with ``from_params(datatype.value_type, ...)``.
    Returns ``Unset`` when no such parameter exists.
    """
    key_re = re.compile(
        r'^%s\[(?P<key>[a-zA-Z0-9_\.]+)\]' % re.escape(path))

    # Map each key as written in the request to its converted form.  The
    # value lookup must use the submitted text: rebuilding the path from the
    # converted key (e.g. int('01') -> '1') would miss the parameter.
    raw_keys = {}
    for p in params.keys():
        m = key_re.match(p)
        if m:
            raw = m.group('key')
            raw_keys[raw] = from_param(datatype.key_type, raw)

    if not raw_keys:
        return Unset

    return {
        key: from_params(datatype.value_type,
                         params, '%s[%s]' % (path, raw), hit_paths)
        for raw, key in raw_keys.items()
    }


def _datatype_name(datatype):
    """Return a readable name for ``datatype``, for use in error messages."""
    if isinstance(datatype, UserType):
        return datatype.name
    # Datatypes can be instances rather than classes (e.g. ArrayType(...));
    # those may not define __name__, so fall back to str().
    return getattr(datatype, '__name__', str(datatype))


def _convert_arg(argdef, value):
    """Convert ``value`` for ``argdef``; raise InvalidInput if that fails."""
    try:
        return from_param(argdef.datatype, value)
    except Exception:
        raise InvalidInput(
            argdef.name,
            value,
            "unable to convert to %s" % _datatype_name(argdef.datatype))


def args_from_args(funcdef, args, kwargs):
    """Convert the host framework's positional and keyword arguments.

    Positional ``args`` are matched to ``funcdef.arguments`` in order, and
    ``kwargs`` by name.  Each value is converted with ``from_param``.  A
    conversion failure raises InvalidInput.  A keyword that names no
    argument raises UnknownArgument, unless ``funcdef.ignore_extra_args``
    is set, in which case it is dropped.

    Returns ``(newargs, newkwargs)``.
    """
    newargs = [
        _convert_arg(argdef, arg)
        for argdef, arg in zip(funcdef.arguments[:len(args)], args)
    ]
    newkwargs = {}
    for argname, value in kwargs.items():
        argdef = funcdef.get_arg(argname)
        # Presumably get_arg returns None for a name it does not know;
        # treat that like an unknown request parameter instead of
        # crashing on ``None.datatype``.
        if argdef is None:
            if funcdef.ignore_extra_args:
                continue
            raise UnknownArgument(argname)
        newkwargs[argname] = _convert_arg(argdef, value)
    return newargs, newkwargs


def args_from_params(funcdef, params):
    """Extract function arguments from request parameters.

    Each argument of ``funcdef`` is looked up with ``from_params`` and
    included only when present.  Parameters that no argument consumed
    (other than ``__body__``) raise UnknownArgument, unless
    ``funcdef.ignore_extra_args`` is set.  The names in the error message
    are sorted, so the message is deterministic.

    Returns ``([], kwargs)``.
    """
    kw = {}
    hit_paths = set()
    for argdef in funcdef.arguments:
        value = from_params(
            argdef.datatype, params, argdef.name, hit_paths)
        if value is not Unset:
            kw[argdef.name] = value
    unknown_paths = set(params.keys()) - hit_paths
    # __body__ carries the raw request body, not an argument.
    unknown_paths.discard('__body__')
    if not funcdef.ignore_extra_args and unknown_paths:
        raise UnknownArgument(', '.join(sorted(unknown_paths)))
    return [], kw


def args_from_body(funcdef, body, mimetype):
    """Parse function arguments from the request body.

    An empty body and form-urlencoded bodies yield no arguments; form
    fields are expected to arrive through the request params instead.
    JSON and XML bodies are parsed by ``wsme.rest.json`` /
    ``wsme.rest.xml``.  When ``funcdef.body_type`` is set, the whole body is
    bound to the last argument.

    Returns ``((), kwargs)``.  Raises ValueError for an unsupported
    mimetype.  Re-raises UnknownArgument unless
    ``funcdef.ignore_extra_args`` is set.
    """
    # Cheap exits first: no need to import the format modules for these.
    if not body:
        return (), {}
    if mimetype == "application/x-www-form-urlencoded":
        # the parameters should have been parsed in params
        return (), {}

    from wsme.rest import json as restjson
    from wsme.rest import xml as restxml

    if mimetype in restjson.accept_content_types:
        dataformat = restjson
    elif mimetype in restxml.accept_content_types:
        dataformat = restxml
    else:
        raise ValueError("Unknow mimetype: %s" % mimetype)

    if funcdef.body_type is not None:
        datatypes = {funcdef.arguments[-1].name: funcdef.body_type}
    else:
        datatypes = dict((a.name, a.datatype) for a in funcdef.arguments)

    try:
        kw = dataformat.parse(
            body, datatypes, bodyarg=funcdef.body_type is not None
        )
    except UnknownArgument:
        if not funcdef.ignore_extra_args:
            raise
        # parse() raised, so there is no result to use; without this
        # assignment the return below hit an unbound ``kw``.
        kw = {}

    return (), kw


def combine_args(funcdef, akw, allow_override=False):
    """Merge several ``(args, kwargs)`` pairs into keyword arguments.

    Positional values are named after ``funcdef.arguments`` in order, and
    keyword names are coerced with ``str``.  Pairs are processed in the
    given order.  A name seen twice raises ClientSideError, unless
    ``allow_override`` is true, in which case the later value wins.

    Returns ``([], kwargs)`` -- everything is returned by keyword.
    """
    merged = {}

    def store(name, value):
        if name in merged and not allow_override:
            raise ClientSideError(
                "Parameter %s was given several times" % name)
        merged[name] = value

    for positional, named in akw:
        for index, arg in enumerate(positional):
            store(funcdef.arguments[index].name, arg)
        for key, value in named.items():
            store(str(key), value)
    return [], merged


def get_args(funcdef, args, kwargs, params, form, body, mimetype):
    """Collect the call arguments for ``funcdef`` from every request source.

    Sources:

    * the host framework ``args`` and ``kwargs``;
    * the request ``params``, plus the ``form`` parameters when given;
    * the request ``body``, falling back to ``params['__body__']`` when no
      body is passed directly.

    Params, form and body must not supply the same argument twice
    (ClientSideError).  Together they may override the host framework
    args and kwargs.  The result is checked with
    ``wsme.runtime.check_arguments`` and returned as ``(args, kwargs)``.
    """
    if not body and '__body__' in params:
        body = params['__body__']

    host_akw = args_from_args(funcdef, args, kwargs)
    param_akw = args_from_params(funcdef, params)
    form_akw = args_from_params(funcdef, form) if form else ((), {})
    body_akw = args_from_body(funcdef, body, mimetype)

    # Request-side sources may not collide with each other...
    request_akw = combine_args(funcdef, (param_akw, form_akw, body_akw))
    # ...but they take precedence over the host framework arguments.
    final_args, final_kwargs = combine_args(
        funcdef,
        (host_akw, request_akw),
        allow_override=True
    )
    wsme.runtime.check_arguments(funcdef, final_args, final_kwargs)
    return final_args, final_kwargs