summaryrefslogtreecommitdiff
path: root/tests/testutils/http_server.py
blob: 6ecb7b5b3beadafcfe5470241f8c766e3edc1f3f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import base64
import html
import multiprocessing
import os
import posixpath
import threading
import urllib.parse
from http.server import SimpleHTTPRequestHandler, HTTPServer, HTTPStatus


class Unauthorized(Exception):
    """Signal that a request did not supply acceptable HTTP Basic credentials.

    Raised while resolving which directory a request may read; the request
    handler's GET/HEAD entry points catch it and reply with a 401 challenge.
    """


class RequestHandler(SimpleHTTPRequestHandler):
    """Static-file handler serving a per-user directory behind HTTP Basic auth.

    The directory served for a request is chosen by ``get_root_dir()`` from
    the ``Authorization`` header and the settings held on ``self.server``
    (``users``, ``anonymous_dir`` and ``realm``, see ``AuthHTTPServer``).
    Requests that cannot be authenticated get ``401 Unauthorized`` with a
    ``WWW-Authenticate`` Basic challenge.
    """

    def get_root_dir(self):
        """Return the directory to serve for the current request.

        Without an ``Authorization`` header, ``self.server.anonymous_dir`` is
        used.  Otherwise the header must be ``Basic <base64(user:password)>``
        naming an entry of ``self.server.users`` (``user -> (password,
        directory)``) whose password matches.

        Raises:
            Unauthorized: no credentials were sent and anonymous access is not
                configured, or the credentials are malformed, name an unknown
                user, or carry the wrong password.
        """
        authorization = self.headers.get('authorization')
        if not authorization:
            if not self.server.anonymous_dir:
                raise Unauthorized('unauthorized')
            return self.server.anonymous_dir

        parts = authorization.split()
        if len(parts) != 2 or parts[0].lower() != 'basic':
            raise Unauthorized('unauthorized')
        try:
            decoded = base64.decodebytes(parts[1].encode('ascii'))
            # RFC 7617: the user-id cannot contain ':' but the password may,
            # so only the first colon separates the two.
            user, password = decoded.decode('ascii').split(':', 1)
            expected_password, directory = self.server.users[user]
        except (ValueError, KeyError):
            # ValueError covers non-ASCII text (UnicodeError), bad base64
            # (binascii.Error) and a missing ':'; KeyError an unknown user.
            raise Unauthorized('unauthorized') from None
        if password != expected_password:
            # A wrong password must produce a 401, never a missing directory.
            raise Unauthorized('unauthorized')
        return directory

    def unauthorized(self):
        """Send a ``401 Unauthorized`` response carrying a Basic challenge.

        Modelled on ``send_error()``, with an added ``WWW-Authenticate``
        header.  The ``Connection: close`` header marks the connection to be
        closed after this response.  No body is written for HEAD requests.
        """
        shortmsg, longmsg = self.responses[HTTPStatus.UNAUTHORIZED]
        self.send_response(HTTPStatus.UNAUTHORIZED, shortmsg)
        self.send_header('Connection', 'close')

        content = (self.error_message_format % {
            'code': HTTPStatus.UNAUTHORIZED,
            'message': html.escape(shortmsg, quote=False),
            'explain': html.escape(longmsg, quote=False),
        })
        body = content.encode('UTF-8', 'replace')
        self.send_header('Content-Type', self.error_content_type)
        self.send_header('Content-Length', str(len(body)))
        self.send_header('WWW-Authenticate',
                         'Basic realm="{}"'.format(self.server.realm))
        # Exactly one end_headers(): a second call emits an extra CRLF that
        # becomes part of the body and breaks the Content-Length framing.
        self.end_headers()

        if self.command != 'HEAD' and body:
            self.wfile.write(body)

    def do_GET(self):
        """Serve GET as the base class does, answering 401 on auth failure."""
        try:
            super().do_GET()
        except Unauthorized:
            self.unauthorized()

    def do_HEAD(self):
        """Serve HEAD as the base class does, answering 401 on auth failure."""
        try:
            super().do_HEAD()
        except Unauthorized:
            self.unauthorized()

    def translate_path(self, path):
        """Map a request path onto the filesystem below ``get_root_dir()``.

        The query string and fragment are dropped and %-escapes are decoded,
        as in the base class.  The path is normalised against ``/``, so
        ``..`` components cannot climb above the served directory.

        Raises:
            Unauthorized: propagated from ``get_root_dir()``.
        """
        path = path.split('?', 1)[0]
        path = path.split('#', 1)[0]
        path = urllib.parse.unquote(path)
        # Anchor at '/' so that a non-absolute request target (e.g. 'GET foo')
        # is treated as rooted rather than crashing the handler.
        path = posixpath.normpath('/' + path.lstrip('/'))
        path = posixpath.relpath(path, '/')
        return os.path.join(self.get_root_dir(), path)


class AuthHTTPServer(HTTPServer):
    """HTTPServer that carries the authentication settings used by its handler.

    Attributes (each instance gets its own copies):
        users: mapping of user name to ``(password, directory)``.
        anonymous_dir: directory served without credentials, or None to
            require authentication.
        realm: realm advertised in the ``WWW-Authenticate`` challenge.
    """

    def __init__(self, *args, **kwargs):
        # Auth state is initialised before the base constructor runs.
        self.realm = 'Realm'
        self.anonymous_dir = None
        self.users = {}
        super().__init__(*args, **kwargs)


class SimpleHttpServer(multiprocessing.Process):
    """Serve files over HTTP, with optional Basic auth, from a child process.

    The server socket is bound to an ephemeral port on 127.0.0.1 when the
    object is created, so ``base_url()`` works immediately.  Configure users
    and the anonymous directory *before* ``start()``: the child process works
    on its own copy of ``self.server``, so later changes made in the parent
    are not seen by it.
    """

    def __init__(self):
        self.__stop = multiprocessing.Queue()
        super().__init__()
        self.server = AuthHTTPServer(('127.0.0.1', 0), RequestHandler)
        self.started = False

    def start(self):
        """Start the child process; recorded so ``stop()`` knows to reap it."""
        self.started = True
        super().start()

    def run(self):
        """Child-process entry point: serve until ``stop()`` posts a message."""
        # serve_forever() blocks, so it runs on a worker thread while this
        # thread waits on the stop queue and then asks the server to shut down.
        serving = threading.Thread(target=self.server.serve_forever)
        serving.start()
        self.__stop.get()
        self.server.shutdown()
        serving.join()

    def stop(self):
        """Stop the child process (if started) and release the server socket.

        Calling it again, or calling it before ``start()``, does not raise.
        """
        if self.started:
            # Post the stop message, but terminate the child straight away
            # rather than waiting for the graceful shutdown in run().
            self.__stop.put(None)
            self.terminate()
            self.join()
            self.__stop.close()
            self.__stop.join_thread()
            # Prevent a second stop() from using the now-closed queue.
            self.started = False
        # The parent holds its own copy of the listening socket.  Unless it is
        # closed, the port keeps accepting TCP connections after the child is
        # gone (clients hang instead of being refused) and the fd leaks.
        self.server.server_close()

    def allow_anonymous(self, cwd):
        """Serve ``cwd`` to requests that send no credentials."""
        self.server.anonymous_dir = cwd

    def add_user(self, user, password, cwd):
        """Register ``user``/``password``, serving ``cwd`` to that user."""
        self.server.users[user] = (password, cwd)

    def base_url(self):
        """Return the ``http://127.0.0.1:<port>`` URL of the bound server."""
        return 'http://127.0.0.1:{}'.format(self.server.server_port)