summaryrefslogtreecommitdiff
path: root/websockify
diff options
context:
space:
mode:
Diffstat (limited to 'websockify')
-rw-r--r--websockify/auth_plugins.py50
-rw-r--r--websockify/websocket.py12
-rwxr-xr-xwebsockify/websocketproxy.py33
3 files changed, 81 insertions, 14 deletions
diff --git a/websockify/auth_plugins.py b/websockify/auth_plugins.py
index 647c26e..924d5de 100644
--- a/websockify/auth_plugins.py
+++ b/websockify/auth_plugins.py
@@ -7,7 +7,15 @@ class BasePlugin(object):
class AuthenticationError(Exception):
- pass
+ def __init__(self, log_msg=None, response_code=403, response_headers={}, response_msg=None):
+ self.code = response_code
+ self.headers = response_headers
+ self.msg = response_msg
+
+ if log_msg is None:
+ log_msg = response_msg
+
+ super(AuthenticationError, self).__init__('%s %s' % (self.code, log_msg))
class InvalidOriginError(AuthenticationError):
@@ -16,8 +24,44 @@ class InvalidOriginError(AuthenticationError):
self.actual_origin = actual
super(InvalidOriginError, self).__init__(
- "Invalid Origin Header: Expected one of "
- "%s, got '%s'" % (expected, actual))
+ response_msg='Invalid Origin',
+ log_msg="Invalid Origin Header: Expected one of "
+ "%s, got '%s'" % (expected, actual))
+
+
+class BasicHTTPAuth(object):
+ def __init__(self, src=None):
+ self.src = src
+
+ def authenticate(self, headers, target_host, target_port):
+ import base64
+
+ auth_header = headers.get('Authorization')
+ if auth_header:
+ if not auth_header.startswith('Basic '):
+ raise AuthenticationError(response_code=403)
+
+ try:
+ user_pass_raw = base64.b64decode(auth_header[6:])
+ except TypeError:
+ raise AuthenticationError(response_code=403)
+
+ user_pass = user_pass_raw.split(':', 1)
+ if len(user_pass) != 2:
+ raise AuthenticationError(response_code=403)
+
+ if not self.validate_creds:
+ raise AuthenticationError(response_code=403)
+
+ else:
+ raise AuthenticationError(response_code=401,
+ response_headers={'WWW-Authenticate': 'Basic realm="Websockify"'})
+
+ def validate_creds(username, password):
+ if '%s:%s' % (username, password) == self.src:
+ return True
+ else:
+ return False
class ExpectOrigin(object):
def __init__(self, src=None):
diff --git a/websockify/websocket.py b/websockify/websocket.py
index 1cbf583..7fa9651 100644
--- a/websockify/websocket.py
+++ b/websockify/websocket.py
@@ -474,9 +474,13 @@ class WebSocketRequestHandler(SimpleHTTPRequestHandler):
"""Upgrade a connection to Websocket, if requested. If this succeeds,
new_websocket_client() will be called. Otherwise, False is returned.
"""
+
if (self.headers.get('upgrade') and
self.headers.get('upgrade').lower() == 'websocket'):
+ # ensure connection is authorized, and determine the target
+ self.validate_connection()
+
if not self.do_websocket_handshake():
return False
@@ -549,6 +553,10 @@ class WebSocketRequestHandler(SimpleHTTPRequestHandler):
""" Do something with a WebSockets client connection. """
raise Exception("WebSocketRequestHandler.new_websocket_client() must be overloaded")
+ def validate_connection(self):
+ """ Ensure that the connection is a valid connection, and set the target. """
+ pass
+
def do_HEAD(self):
if self.only_upgrade:
self.send_error(405, "Method Not Allowed")
@@ -789,7 +797,7 @@ class WebSocketServer(object):
"""
ready = select.select([sock], [], [], 3)[0]
-
+
if not ready:
raise self.EClose("ignoring socket not ready")
# Peek, but do not read the data so that we have a opportunity
@@ -903,7 +911,7 @@ class WebSocketServer(object):
def top_new_client(self, startsock, address):
""" Do something with a WebSockets client connection. """
- # handler process
+ # handler process
client = None
try:
try:
diff --git a/websockify/websocketproxy.py b/websockify/websocketproxy.py
index 029b6f3..46ab545 100755
--- a/websockify/websocketproxy.py
+++ b/websockify/websocketproxy.py
@@ -18,6 +18,7 @@ try: from http.server import HTTPServer
except: from BaseHTTPServer import HTTPServer
import select
from websockify import websocket
+from websockify import auth_plugins as auth
try:
from urllib.parse import parse_qs, urlparse
except:
@@ -37,20 +38,34 @@ Traffic Legend:
< - Client send
<. - Client send partial
"""
+
+ def send_auth_error(self, ex):
+ self.send_response(ex.code, ex.msg)
+ self.send_header('Content-Type', 'text/html')
+ for name, val in ex.headers.items():
+ self.send_header(name, val)
+
+ self.end_headers()
+
+ def validate_connection(self):
+ if self.server.token_plugin:
+ (self.server.target_host, self.server.target_port) = self.get_target(self.server.token_plugin, self.path)
+
+ if self.server.auth_plugin:
+ try:
+ self.server.auth_plugin.authenticate(
+ headers=self.headers, target_host=self.server.target_host,
+ target_port=self.server.target_port)
+ except auth.AuthenticationError:
+ ex = sys.exc_info()[1]
+ self.send_auth_error(ex)
+ raise
def new_websocket_client(self):
"""
Called after a new WebSocket connection has been established.
"""
- # Checks if we receive a token, and look
- # for a valid target for it then
- if self.server.token_plugin:
- (self.server.target_host, self.server.target_port) = self.get_target(self.server.token_plugin, self.path)
-
- if self.server.auth_plugin:
- self.server.auth_plugin.authenticate(
- headers=self.headers, target_host=self.server.target_host,
- target_port=self.server.target_port)
+ # Checking for a token is done in validate_connection()
# Connect to the target
if self.server.wrap_cmd: