diff options
author | Ian Lance Taylor <ian@gcc.gnu.org> | 2011-10-26 23:57:58 +0000 |
---|---|---|
committer | Ian Lance Taylor <ian@gcc.gnu.org> | 2011-10-26 23:57:58 +0000 |
commit | d8f412571f8768df2d3239e72392dfeabbad1559 (patch) | |
tree | 19d182df05ead7ff8ba7ee00a7d57555e1383fdf /libgo/go/websocket | |
parent | e0c39d66d4f0607177b1cf8995dda56a667e07b3 (diff) | |
download | gcc-d8f412571f8768df2d3239e72392dfeabbad1559.tar.gz |
Update Go library to last weekly.
From-SVN: r180552
Diffstat (limited to 'libgo/go/websocket')
-rw-r--r-- | libgo/go/websocket/client.go | 309 | ||||
-rw-r--r-- | libgo/go/websocket/hixie.go | 696 | ||||
-rw-r--r-- | libgo/go/websocket/hixie_test.go | 201 | ||||
-rw-r--r-- | libgo/go/websocket/hybi.go | 550 | ||||
-rw-r--r-- | libgo/go/websocket/hybi_test.go | 584 | ||||
-rw-r--r-- | libgo/go/websocket/server.go | 210 | ||||
-rw-r--r-- | libgo/go/websocket/websocket.go | 455 | ||||
-rw-r--r-- | libgo/go/websocket/websocket_test.go | 196 |
8 files changed, 2573 insertions, 628 deletions
diff --git a/libgo/go/websocket/client.go b/libgo/go/websocket/client.go index 74bede4249f..b7eaafda163 100644 --- a/libgo/go/websocket/client.go +++ b/libgo/go/websocket/client.go @@ -6,114 +6,119 @@ package websocket import ( "bufio" - "bytes" "crypto/tls" - "fmt" - "http" "io" "net" "os" - "rand" - "strings" "url" ) -type ProtocolError struct { - ErrorString string -} - -func (err *ProtocolError) String() string { return string(err.ErrorString) } - -var ( - ErrBadScheme = &ProtocolError{"bad scheme"} - ErrBadStatus = &ProtocolError{"bad status"} - ErrBadUpgrade = &ProtocolError{"missing or bad upgrade"} - ErrBadWebSocketOrigin = &ProtocolError{"missing or bad WebSocket-Origin"} - ErrBadWebSocketLocation = &ProtocolError{"missing or bad WebSocket-Location"} - ErrBadWebSocketProtocol = &ProtocolError{"missing or bad WebSocket-Protocol"} - ErrChallengeResponse = &ProtocolError{"mismatch challenge/response"} - secKeyRandomChars [0x30 - 0x21 + 0x7F - 0x3A]byte -) - +// DialError is an error that occurs while dialling a websocket server. type DialError struct { - URL string - Protocol string - Origin string - Error os.Error + *Config + Error os.Error } func (e *DialError) String() string { - return "websocket.Dial " + e.URL + ": " + e.Error.String() + return "websocket.Dial " + e.Config.Location.String() + ": " + e.Error.String() } -func init() { - i := 0 - for ch := byte(0x21); ch < 0x30; ch++ { - secKeyRandomChars[i] = ch - i++ +// NewConfig creates a new WebSocket config for client connection. +func NewConfig(server, origin string) (config *Config, err os.Error) { + config = new(Config) + config.Version = ProtocolVersionHybi13 + config.Location, err = url.ParseRequest(server) + if err != nil { + return } - for ch := byte(0x3a); ch < 0x7F; ch++ { - secKeyRandomChars[i] = ch - i++ + config.Origin, err = url.ParseRequest(origin) + if err != nil { + return } + return } -type handshaker func(resourceName, host, origin, location, protocol string, br *bufio.Reader, bw *bufio.Writer) os.Error - -// newClient creates a new Web Socket client connection. -func newClient(resourceName, host, origin, location, protocol string, rwc io.ReadWriteCloser, handshake handshaker) (ws *Conn, err os.Error) { +// NewClient creates a new WebSocket client connection over rwc. +func NewClient(config *Config, rwc io.ReadWriteCloser) (ws *Conn, err os.Error) { br := bufio.NewReader(rwc) bw := bufio.NewWriter(rwc) - err = handshake(resourceName, host, origin, location, protocol, br, bw) + switch config.Version { + case ProtocolVersionHixie75: + err = hixie75ClientHandshake(config, br, bw) + case ProtocolVersionHixie76, ProtocolVersionHybi00: + err = hixie76ClientHandshake(config, br, bw) + case ProtocolVersionHybi08, ProtocolVersionHybi13: + err = hybiClientHandshake(config, br, bw) + default: + err = ErrBadProtocolVersion + } if err != nil { return } buf := bufio.NewReadWriter(br, bw) - ws = newConn(origin, location, protocol, buf, rwc) + switch config.Version { + case ProtocolVersionHixie75, ProtocolVersionHixie76, ProtocolVersionHybi00: + ws = newHixieClientConn(config, buf, rwc) + case ProtocolVersionHybi08, ProtocolVersionHybi13: + ws = newHybiClientConn(config, buf, rwc) + } return } /* -Dial opens a new client connection to a Web Socket. +Dial opens a new client connection to a WebSocket. A trivial example client: package main import ( - "websocket" + "http" + "log" "strings" + "websocket" ) func main() { - ws, err := websocket.Dial("ws://localhost/ws", "", "http://localhost/"); - if err != nil { - panic("Dial: " + err.String()) + origin := "http://localhost/" + url := "ws://localhost/ws" + ws, err := websocket.Dial(url, "", origin) + if err != nil { + log.Fatal(err) } if _, err := ws.Write([]byte("hello, world!\n")); err != nil { - panic("Write: " + err.String()) + log.Fatal(err) } var msg = make([]byte, 512); if n, err := ws.Read(msg); err != nil { - panic("Read: " + err.String()) + log.Fatal(err) } // use msg[0:n] } */ func Dial(url_, protocol, origin string) (ws *Conn, err os.Error) { - var client net.Conn - - parsedUrl, err := url.Parse(url_) + config, err := NewConfig(url_, origin) if err != nil { - goto Error + return nil, err } + return DialConfig(config) +} - switch parsedUrl.Scheme { +// DialConfig opens a new client connection to a WebSocket with a config. +func DialConfig(config *Config) (ws *Conn, err os.Error) { + var client net.Conn + if config.Location == nil { + return nil, &DialError{config, ErrBadWebSocketLocation} + } + if config.Origin == nil { + return nil, &DialError{config, ErrBadWebSocketOrigin} + } + switch config.Location.Scheme { case "ws": - client, err = net.Dial("tcp", parsedUrl.Host) + client, err = net.Dial("tcp", config.Location.Host) case "wss": - client, err = tls.Dial("tcp", parsedUrl.Host, nil) + client, err = tls.Dial("tcp", config.Location.Host, config.TlsConfig) default: err = ErrBadScheme @@ -122,202 +127,12 @@ func Dial(url_, protocol, origin string) (ws *Conn, err os.Error) { goto Error } - ws, err = newClient(parsedUrl.RawPath, parsedUrl.Host, origin, url_, protocol, client, handshake) + ws, err = NewClient(config, client) if err != nil { goto Error } return Error: - return nil, &DialError{url_, protocol, origin, err} -} - -/* -Generates handshake key as described in 4.1 Opening handshake step 16 to 22. -cf. http://www.whatwg.org/specs/web-socket-protocol/ -*/ -func generateKeyNumber() (key string, number uint32) { - // 16. Let /spaces_n/ be a random integer from 1 to 12 inclusive. - spaces := rand.Intn(12) + 1 - - // 17. Let /max_n/ be the largest integer not greater than - // 4,294,967,295 divided by /spaces_n/ - max := int(4294967295 / uint32(spaces)) - - // 18. Let /number_n/ be a random integer from 0 to /max_n/ inclusive. - number = uint32(rand.Intn(max + 1)) - - // 19. Let /product_n/ be the result of multiplying /number_n/ and - // /spaces_n/ together. - product := number * uint32(spaces) - - // 20. Let /key_n/ be a string consisting of /product_n/, expressed - // in base ten using the numerals in the range U+0030 DIGIT ZERO (0) - // to U+0039 DIGIT NINE (9). - key = fmt.Sprintf("%d", product) - - // 21. Insert between one and twelve random characters from the ranges - // U+0021 to U+002F and U+003A to U+007E into /key_n/ at random - // positions. - n := rand.Intn(12) + 1 - for i := 0; i < n; i++ { - pos := rand.Intn(len(key)) + 1 - ch := secKeyRandomChars[rand.Intn(len(secKeyRandomChars))] - key = key[0:pos] + string(ch) + key[pos:] - } - - // 22. Insert /spaces_n/ U+0020 SPACE characters into /key_n/ at random - // positions other than the start or end of the string. - for i := 0; i < spaces; i++ { - pos := rand.Intn(len(key)-1) + 1 - key = key[0:pos] + " " + key[pos:] - } - - return -} - -/* -Generates handshake key_3 as described in 4.1 Opening handshake step 26. -cf. http://www.whatwg.org/specs/web-socket-protocol/ -*/ -func generateKey3() (key []byte) { - // 26. Let /key3/ be a string consisting of eight random bytes (or - // equivalently, a random 64 bit integer encoded in big-endian order). - key = make([]byte, 8) - for i := 0; i < 8; i++ { - key[i] = byte(rand.Intn(256)) - } - return -} - -/* -Web Socket protocol handshake based on -http://www.whatwg.org/specs/web-socket-protocol/ -(draft of http://tools.ietf.org/html/draft-hixie-thewebsocketprotocol) -*/ -func handshake(resourceName, host, origin, location, protocol string, br *bufio.Reader, bw *bufio.Writer) (err os.Error) { - // 4.1. Opening handshake. - // Step 5. send a request line. - bw.WriteString("GET " + resourceName + " HTTP/1.1\r\n") - - // Step 6-14. push request headers in fields. - var fields []string - fields = append(fields, "Upgrade: WebSocket\r\n") - fields = append(fields, "Connection: Upgrade\r\n") - fields = append(fields, "Host: "+host+"\r\n") - fields = append(fields, "Origin: "+origin+"\r\n") - if protocol != "" { - fields = append(fields, "Sec-WebSocket-Protocol: "+protocol+"\r\n") - } - // TODO(ukai): Step 15. send cookie if any. - - // Step 16-23. generate keys and push Sec-WebSocket-Key<n> in fields. - key1, number1 := generateKeyNumber() - key2, number2 := generateKeyNumber() - fields = append(fields, "Sec-WebSocket-Key1: "+key1+"\r\n") - fields = append(fields, "Sec-WebSocket-Key2: "+key2+"\r\n") - - // Step 24. shuffle fields and send them out. - for i := 1; i < len(fields); i++ { - j := rand.Intn(i) - fields[i], fields[j] = fields[j], fields[i] - } - for i := 0; i < len(fields); i++ { - bw.WriteString(fields[i]) - } - // Step 25. send CRLF. - bw.WriteString("\r\n") - - // Step 26. generate 8 bytes random key. - key3 := generateKey3() - // Step 27. send it out. - bw.Write(key3) - if err = bw.Flush(); err != nil { - return - } - - // Step 28-29, 32-40. read response from server. - resp, err := http.ReadResponse(br, &http.Request{Method: "GET"}) - if err != nil { - return err - } - // Step 30. check response code is 101. - if resp.StatusCode != 101 { - return ErrBadStatus - } - - // Step 41. check websocket headers. - if resp.Header.Get("Upgrade") != "WebSocket" || - strings.ToLower(resp.Header.Get("Connection")) != "upgrade" { - return ErrBadUpgrade - } - - if resp.Header.Get("Sec-Websocket-Origin") != origin { - return ErrBadWebSocketOrigin - } - - if resp.Header.Get("Sec-Websocket-Location") != location { - return ErrBadWebSocketLocation - } - - if protocol != "" && resp.Header.Get("Sec-Websocket-Protocol") != protocol { - return ErrBadWebSocketProtocol - } - - // Step 42-43. get expected data from challenge data. - expected, err := getChallengeResponse(number1, number2, key3) - if err != nil { - return err - } - - // Step 44. read 16 bytes from server. - reply := make([]byte, 16) - if _, err = io.ReadFull(br, reply); err != nil { - return err - } - - // Step 45. check the reply equals to expected data. - if !bytes.Equal(expected, reply) { - return ErrChallengeResponse - } - // WebSocket connection is established. - return -} - -/* -Handshake described in (soon obsolete) -draft-hixie-thewebsocket-protocol-75. -*/ -func draft75handshake(resourceName, host, origin, location, protocol string, br *bufio.Reader, bw *bufio.Writer) (err os.Error) { - bw.WriteString("GET " + resourceName + " HTTP/1.1\r\n") - bw.WriteString("Upgrade: WebSocket\r\n") - bw.WriteString("Connection: Upgrade\r\n") - bw.WriteString("Host: " + host + "\r\n") - bw.WriteString("Origin: " + origin + "\r\n") - if protocol != "" { - bw.WriteString("WebSocket-Protocol: " + protocol + "\r\n") - } - bw.WriteString("\r\n") - bw.Flush() - resp, err := http.ReadResponse(br, &http.Request{Method: "GET"}) - if err != nil { - return - } - if resp.Status != "101 Web Socket Protocol Handshake" { - return ErrBadStatus - } - if resp.Header.Get("Upgrade") != "WebSocket" || - resp.Header.Get("Connection") != "Upgrade" { - return ErrBadUpgrade - } - if resp.Header.Get("Websocket-Origin") != origin { - return ErrBadWebSocketOrigin - } - if resp.Header.Get("Websocket-Location") != location { - return ErrBadWebSocketLocation - } - if protocol != "" && resp.Header.Get("Websocket-Protocol") != protocol { - return ErrBadWebSocketProtocol - } - return + return nil, &DialError{config, err} } diff --git a/libgo/go/websocket/hixie.go b/libgo/go/websocket/hixie.go new file mode 100644 index 00000000000..841ff3c3ef5 --- /dev/null +++ b/libgo/go/websocket/hixie.go @@ -0,0 +1,696 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package websocket + +// This file implements a protocol of Hixie draft version 75 and 76 +// (draft 76 equals to hybi 00) + +import ( + "bufio" + "bytes" + "crypto/md5" + "encoding/binary" + "fmt" + "http" + "io" + "io/ioutil" + "os" + "rand" + "strconv" + "strings" + "url" +) + +// An aray of characters to be randomly inserted to construct Sec-WebSocket-Key +// value. It holds characters from ranges U+0021 to U+002F and U+003A to U+007E. +// See Step 21 in Section 4.1 Opening handshake. +// http://tools.ietf.org/html/draft-ietf-hybi-thewebsocketprotocol-00#page-22 +var secKeyRandomChars [0x30 - 0x21 + 0x7F - 0x3A]byte + +func init() { + i := 0 + for ch := byte(0x21); ch < 0x30; ch++ { + secKeyRandomChars[i] = ch + i++ + } + for ch := byte(0x3a); ch < 0x7F; ch++ { + secKeyRandomChars[i] = ch + i++ + } +} + +type byteReader interface { + ReadByte() (byte, os.Error) +} + +// readHixieLength reads frame length for frame type 0x80-0xFF +// as defined in Hixie draft. +// See section 4.2 Data framing. +// http://tools.ietf.org/html/draft-ietf-hybi-thewebsocketprotocol-00#section-4.2 +func readHixieLength(r byteReader) (length int64, lengthFields []byte, err os.Error) { + for { + c, err := r.ReadByte() + if err != nil { + return 0, nil, err + } + lengthFields = append(lengthFields, c) + length = length*128 + int64(c&0x7f) + if c&0x80 == 0 { + break + } + } + return +} + +// A hixieLengthFrameReader is a reader for frame type 0x80-0xFF +// as defined in hixie draft. +type hixieLengthFrameReader struct { + reader io.Reader + FrameType byte + Length int64 + header *bytes.Buffer + length int +} + +func (frame *hixieLengthFrameReader) Read(msg []byte) (n int, err os.Error) { + return frame.reader.Read(msg) +} + +func (frame *hixieLengthFrameReader) PayloadType() byte { + if frame.FrameType == '\xff' && frame.Length == 0 { + return CloseFrame + } + return UnknownFrame +} + +func (frame *hixieLengthFrameReader) HeaderReader() io.Reader { + if frame.header == nil { + return nil + } + if frame.header.Len() == 0 { + frame.header = nil + return nil + } + return frame.header +} + +func (frame *hixieLengthFrameReader) TrailerReader() io.Reader { return nil } + +func (frame *hixieLengthFrameReader) Len() (n int) { return frame.length } + +// A HixieSentinelFrameReader is a reader for frame type 0x00-0x7F +// as defined in hixie draft. +type hixieSentinelFrameReader struct { + reader *bufio.Reader + FrameType byte + header *bytes.Buffer + data []byte + seenTrailer bool + trailer *bytes.Buffer +} + +func (frame *hixieSentinelFrameReader) Read(msg []byte) (n int, err os.Error) { + if len(frame.data) == 0 { + if frame.seenTrailer { + return 0, os.EOF + } + frame.data, err = frame.reader.ReadSlice('\xff') + if err == nil { + frame.seenTrailer = true + frame.data = frame.data[:len(frame.data)-1] // trim \xff + frame.trailer = bytes.NewBuffer([]byte{0xff}) + } + } + n = copy(msg, frame.data) + frame.data = frame.data[n:] + return n, err +} + +func (frame *hixieSentinelFrameReader) PayloadType() byte { + if frame.FrameType == 0 { + return TextFrame + } + return UnknownFrame +} + +func (frame *hixieSentinelFrameReader) HeaderReader() io.Reader { + if frame.header == nil { + return nil + } + if frame.header.Len() == 0 { + frame.header = nil + return nil + } + return frame.header +} + +func (frame *hixieSentinelFrameReader) TrailerReader() io.Reader { + if frame.trailer == nil { + return nil + } + if frame.trailer.Len() == 0 { + frame.trailer = nil + return nil + } + return frame.trailer +} + +func (frame *hixieSentinelFrameReader) Len() int { return -1 } + +// A HixieFrameReaderFactory creates new frame reader based on its frame type. +type hixieFrameReaderFactory struct { + *bufio.Reader +} + +func (buf hixieFrameReaderFactory) NewFrameReader() (r frameReader, err os.Error) { + var header []byte + var b byte + b, err = buf.ReadByte() + if err != nil { + return + } + header = append(header, b) + if b&0x80 == 0x80 { + length, lengthFields, err := readHixieLength(buf.Reader) + if err != nil { + return nil, err + } + if length == 0 { + return nil, os.EOF + } + header = append(header, lengthFields...) + return &hixieLengthFrameReader{ + reader: io.LimitReader(buf.Reader, length), + FrameType: b, + Length: length, + header: bytes.NewBuffer(header)}, err + } + return &hixieSentinelFrameReader{ + reader: buf.Reader, + FrameType: b, + header: bytes.NewBuffer(header)}, err +} + +type hixiFrameWriter struct { + writer *bufio.Writer +} + +func (frame *hixiFrameWriter) Write(msg []byte) (n int, err os.Error) { + frame.writer.WriteByte(0) + frame.writer.Write(msg) + frame.writer.WriteByte(0xff) + err = frame.writer.Flush() + return len(msg), err +} + +func (frame *hixiFrameWriter) Close() os.Error { return nil } + +type hixiFrameWriterFactory struct { + *bufio.Writer +} + +func (buf hixiFrameWriterFactory) NewFrameWriter(payloadType byte) (frame frameWriter, err os.Error) { + if payloadType != TextFrame { + return nil, ErrNotSupported + } + return &hixiFrameWriter{writer: buf.Writer}, nil +} + +type hixiFrameHandler struct { + conn *Conn +} + +func (handler *hixiFrameHandler) HandleFrame(frame frameReader) (r frameReader, err os.Error) { + if header := frame.HeaderReader(); header != nil { + io.Copy(ioutil.Discard, header) + } + if frame.PayloadType() != TextFrame { + io.Copy(ioutil.Discard, frame) + return nil, nil + } + return frame, nil +} + +func (handler *hixiFrameHandler) WriteClose(_ int) (err os.Error) { + handler.conn.wio.Lock() + defer handler.conn.wio.Unlock() + closingFrame := []byte{'\xff', '\x00'} + handler.conn.buf.Write(closingFrame) + return handler.conn.buf.Flush() +} + +// newHixiConn creates a new WebSocket connection speaking hixie draft protocol. +func newHixieConn(config *Config, buf *bufio.ReadWriter, rwc io.ReadWriteCloser, request *http.Request) *Conn { + if buf == nil { + br := bufio.NewReader(rwc) + bw := bufio.NewWriter(rwc) + buf = bufio.NewReadWriter(br, bw) + } + ws := &Conn{config: config, request: request, buf: buf, rwc: rwc, + frameReaderFactory: hixieFrameReaderFactory{buf.Reader}, + frameWriterFactory: hixiFrameWriterFactory{buf.Writer}, + PayloadType: TextFrame} + ws.frameHandler = &hixiFrameHandler{ws} + return ws +} + +// getChallengeResponse computes the expected response from the +// challenge as described in section 5.1 Opening Handshake steps 42 to +// 43 of http://www.whatwg.org/specs/web-socket-protocol/ +func getChallengeResponse(number1, number2 uint32, key3 []byte) (expected []byte, err os.Error) { + // 41. Let /challenge/ be the concatenation of /number_1/, expressed + // a big-endian 32 bit integer, /number_2/, expressed in a big- + // endian 32 bit integer, and the eight bytes of /key_3/ in the + // order they were sent to the wire. + challenge := make([]byte, 16) + binary.BigEndian.PutUint32(challenge[0:], number1) + binary.BigEndian.PutUint32(challenge[4:], number2) + copy(challenge[8:], key3) + + // 42. Let /expected/ be the MD5 fingerprint of /challenge/ as a big- + // endian 128 bit string. + h := md5.New() + if _, err = h.Write(challenge); err != nil { + return + } + expected = h.Sum() + return +} + +// Generates handshake key as described in 4.1 Opening handshake step 16 to 22. +// cf. http://tools.ietf.org/html/draft-ietf-hybi-thewebsocketprotocol-00 +func generateKeyNumber() (key string, number uint32) { + // 16. Let /spaces_n/ be a random integer from 1 to 12 inclusive. + spaces := rand.Intn(12) + 1 + + // 17. Let /max_n/ be the largest integer not greater than + // 4,294,967,295 divided by /spaces_n/ + max := int(4294967295 / uint32(spaces)) + + // 18. Let /number_n/ be a random integer from 0 to /max_n/ inclusive. + number = uint32(rand.Intn(max + 1)) + + // 19. Let /product_n/ be the result of multiplying /number_n/ and + // /spaces_n/ together. + product := number * uint32(spaces) + + // 20. Let /key_n/ be a string consisting of /product_n/, expressed + // in base ten using the numerals in the range U+0030 DIGIT ZERO (0) + // to U+0039 DIGIT NINE (9). + key = fmt.Sprintf("%d", product) + + // 21. Insert between one and twelve random characters from the ranges + // U+0021 to U+002F and U+003A to U+007E into /key_n/ at random + // positions. + n := rand.Intn(12) + 1 + for i := 0; i < n; i++ { + pos := rand.Intn(len(key)) + 1 + ch := secKeyRandomChars[rand.Intn(len(secKeyRandomChars))] + key = key[0:pos] + string(ch) + key[pos:] + } + + // 22. Insert /spaces_n/ U+0020 SPACE characters into /key_n/ at random + // positions other than the start or end of the string. + for i := 0; i < spaces; i++ { + pos := rand.Intn(len(key)-1) + 1 + key = key[0:pos] + " " + key[pos:] + } + + return +} + +// Generates handshake key_3 as described in 4.1 Opening handshake step 26. +// cf. http://tools.ietf.org/html/draft-ietf-hybi-thewebsocketprotocol-00 +func generateKey3() (key []byte) { + // 26. Let /key3/ be a string consisting of eight random bytes (or + // equivalently, a random 64 bit integer encoded in big-endian order). + key = make([]byte, 8) + for i := 0; i < 8; i++ { + key[i] = byte(rand.Intn(256)) + } + return +} + +// Cilent handhake described in (soon obsolete) +// draft-ietf-hybi-thewebsocket-protocol-00 +// (draft-hixie-thewebsocket-protocol-76) +func hixie76ClientHandshake(config *Config, br *bufio.Reader, bw *bufio.Writer) (err os.Error) { + switch config.Version { + case ProtocolVersionHixie76, ProtocolVersionHybi00: + default: + panic("wrong protocol version.") + } + // 4.1. Opening handshake. + // Step 5. send a request line. + bw.WriteString("GET " + config.Location.RawPath + " HTTP/1.1\r\n") + + // Step 6-14. push request headers in fields. + fields := []string{ + "Upgrade: WebSocket\r\n", + "Connection: Upgrade\r\n", + "Host: " + config.Location.Host + "\r\n", + "Origin: " + config.Origin.String() + "\r\n", + } + if len(config.Protocol) > 0 { + if len(config.Protocol) != 1 { + return ErrBadWebSocketProtocol + } + fields = append(fields, "Sec-WebSocket-Protocol: "+config.Protocol[0]+"\r\n") + } + // TODO(ukai): Step 15. send cookie if any. + + // Step 16-23. generate keys and push Sec-WebSocket-Key<n> in fields. + key1, number1 := generateKeyNumber() + key2, number2 := generateKeyNumber() + if config.handshakeData != nil { + key1 = config.handshakeData["key1"] + n, err := strconv.Atoui(config.handshakeData["number1"]) + if err != nil { + panic(err) + } + number1 = uint32(n) + key2 = config.handshakeData["key2"] + n, err = strconv.Atoui(config.handshakeData["number2"]) + if err != nil { + panic(err) + } + number2 = uint32(n) + } + fields = append(fields, "Sec-WebSocket-Key1: "+key1+"\r\n") + fields = append(fields, "Sec-WebSocket-Key2: "+key2+"\r\n") + + // Step 24. shuffle fields and send them out. + for i := 1; i < len(fields); i++ { + j := rand.Intn(i) + fields[i], fields[j] = fields[j], fields[i] + } + for i := 0; i < len(fields); i++ { + bw.WriteString(fields[i]) + } + // Step 25. send CRLF. + bw.WriteString("\r\n") + + // Step 26. generate 8 bytes random key. + key3 := generateKey3() + if config.handshakeData != nil { + key3 = []byte(config.handshakeData["key3"]) + } + // Step 27. send it out. + bw.Write(key3) + if err = bw.Flush(); err != nil { + return + } + + // Step 28-29, 32-40. read response from server. + resp, err := http.ReadResponse(br, &http.Request{Method: "GET"}) + if err != nil { + return err + } + // Step 30. check response code is 101. + if resp.StatusCode != 101 { + return ErrBadStatus + } + + // Step 41. check websocket headers. + if resp.Header.Get("Upgrade") != "WebSocket" || + strings.ToLower(resp.Header.Get("Connection")) != "upgrade" { + return ErrBadUpgrade + } + + if resp.Header.Get("Sec-Websocket-Origin") != config.Origin.String() { + return ErrBadWebSocketOrigin + } + + if resp.Header.Get("Sec-Websocket-Location") != config.Location.String() { + return ErrBadWebSocketLocation + } + + if len(config.Protocol) > 0 && resp.Header.Get("Sec-Websocket-Protocol") != config.Protocol[0] { + return ErrBadWebSocketProtocol + } + + // Step 42-43. get expected data from challenge data. + expected, err := getChallengeResponse(number1, number2, key3) + if err != nil { + return err + } + + // Step 44. read 16 bytes from server. + reply := make([]byte, 16) + if _, err = io.ReadFull(br, reply); err != nil { + return err + } + + // Step 45. check the reply equals to expected data. + if !bytes.Equal(expected, reply) { + return ErrChallengeResponse + } + // WebSocket connection is established. + return +} + +// Client Handshake described in (soon obsolete) +// draft-hixie-thewebsocket-protocol-75. +func hixie75ClientHandshake(config *Config, br *bufio.Reader, bw *bufio.Writer) (err os.Error) { + if config.Version != ProtocolVersionHixie75 { + panic("wrong protocol version.") + } + bw.WriteString("GET " + config.Location.RawPath + " HTTP/1.1\r\n") + bw.WriteString("Upgrade: WebSocket\r\n") + bw.WriteString("Connection: Upgrade\r\n") + bw.WriteString("Host: " + config.Location.Host + "\r\n") + bw.WriteString("Origin: " + config.Origin.String() + "\r\n") + if len(config.Protocol) > 0 { + if len(config.Protocol) != 1 { + return ErrBadWebSocketProtocol + } + bw.WriteString("WebSocket-Protocol: " + config.Protocol[0] + "\r\n") + } + bw.WriteString("\r\n") + bw.Flush() + resp, err := http.ReadResponse(br, &http.Request{Method: "GET"}) + if err != nil { + return + } + if resp.Status != "101 Web Socket Protocol Handshake" { + return ErrBadStatus + } + if resp.Header.Get("Upgrade") != "WebSocket" || + resp.Header.Get("Connection") != "Upgrade" { + return ErrBadUpgrade + } + if resp.Header.Get("Websocket-Origin") != config.Origin.String() { + return ErrBadWebSocketOrigin + } + if resp.Header.Get("Websocket-Location") != config.Location.String() { + return ErrBadWebSocketLocation + } + if len(config.Protocol) > 0 && resp.Header.Get("Websocket-Protocol") != config.Protocol[0] { + return ErrBadWebSocketProtocol + } + return +} + +// newHixieClientConn returns new WebSocket connection speaking hixie draft protocol. +func newHixieClientConn(config *Config, buf *bufio.ReadWriter, rwc io.ReadWriteCloser) *Conn { + return newHixieConn(config, buf, rwc, nil) +} + +// Gets key number from Sec-WebSocket-Key<n>: field as described +// in 5.2 Sending the server's opening handshake, 4. +func getKeyNumber(s string) (r uint32) { + // 4. Let /key-number_n/ be the digits (characters in the range + // U+0030 DIGIT ZERO (0) to U+0039 DIGIT NINE (9)) in /key_1/, + // interpreted as a base ten integer, ignoring all other characters + // in /key_n/. + r = 0 + for i := 0; i < len(s); i++ { + if s[i] >= '0' && s[i] <= '9' { + r = r*10 + uint32(s[i]) - '0' + } + } + return +} + +// A Hixie76ServerHandshaker performs a server handshake using +// hixie draft 76 protocol. +type hixie76ServerHandshaker struct { + *Config + challengeResponse []byte +} + +func (c *hixie76ServerHandshaker) ReadHandshake(buf *bufio.Reader, req *http.Request) (code int, err os.Error) { + c.Version = ProtocolVersionHybi00 + if req.Method != "GET" { + return http.StatusMethodNotAllowed, ErrBadRequestMethod + } + // HTTP version can be safely ignored. + + if strings.ToLower(req.Header.Get("Upgrade")) != "websocket" || + strings.ToLower(req.Header.Get("Connection")) != "upgrade" { + return http.StatusBadRequest, ErrNotWebSocket + } + + // TODO(ukai): check Host + c.Origin, err = url.ParseRequest(req.Header.Get("Origin")) + if err != nil { + return http.StatusBadRequest, err + } + + key1 := req.Header.Get("Sec-Websocket-Key1") + if key1 == "" { + return http.StatusBadRequest, ErrChallengeResponse + } + key2 := req.Header.Get("Sec-Websocket-Key2") + if key2 == "" { + return http.StatusBadRequest, ErrChallengeResponse + } + key3 := make([]byte, 8) + if _, err := io.ReadFull(buf, key3); err != nil { + return http.StatusBadRequest, ErrChallengeResponse + } + + var scheme string + if req.TLS != nil { + scheme = "wss" + } else { + scheme = "ws" + } + c.Location, err = url.ParseRequest(scheme + "://" + req.Host + req.URL.RawPath) + if err != nil { + return http.StatusBadRequest, err + } + + // Step 4. get key number in Sec-WebSocket-Key<n> fields. + keyNumber1 := getKeyNumber(key1) + keyNumber2 := getKeyNumber(key2) + + // Step 5. get number of spaces in Sec-WebSocket-Key<n> fields. + space1 := uint32(strings.Count(key1, " ")) + space2 := uint32(strings.Count(key2, " ")) + if space1 == 0 || space2 == 0 { + return http.StatusBadRequest, ErrChallengeResponse + } + + // Step 6. key number must be an integral multiple of spaces. + if keyNumber1%space1 != 0 || keyNumber2%space2 != 0 { + return http.StatusBadRequest, ErrChallengeResponse + } + + // Step 7. let part be key number divided by spaces. + part1 := keyNumber1 / space1 + part2 := keyNumber2 / space2 + + // Step 8. let challenge be concatenation of part1, part2 and key3. + // Step 9. get MD5 fingerprint of challenge. + c.challengeResponse, err = getChallengeResponse(part1, part2, key3) + if err != nil { + return http.StatusInternalServerError, err + } + protocol := strings.TrimSpace(req.Header.Get("Sec-Websocket-Protocol")) + protocols := strings.Split(protocol, ",") + for i := 0; i < len(protocols); i++ { + c.Protocol = append(c.Protocol, strings.TrimSpace(protocols[i])) + } + + return http.StatusSwitchingProtocols, nil +} + +func (c *hixie76ServerHandshaker) AcceptHandshake(buf *bufio.Writer) (err os.Error) { + if len(c.Protocol) > 0 { + if len(c.Protocol) != 1 { + return ErrBadWebSocketProtocol + } + } + + // Step 10. send response status line. + buf.WriteString("HTTP/1.1 101 WebSocket Protocol Handshake\r\n") + // Step 11. send response headers. + buf.WriteString("Upgrade: WebSocket\r\n") + buf.WriteString("Connection: Upgrade\r\n") + buf.WriteString("Sec-WebSocket-Origin: " + c.Origin.String() + "\r\n") + buf.WriteString("Sec-WebSocket-Location: " + c.Location.String() + "\r\n") + if len(c.Protocol) > 0 { + buf.WriteString("Sec-WebSocket-Protocol: " + c.Protocol[0] + "\r\n") + } + // Step 12. send CRLF. + buf.WriteString("\r\n") + // Step 13. send response data. + buf.Write(c.challengeResponse) + return buf.Flush() +} + +func (c *hixie76ServerHandshaker) NewServerConn(buf *bufio.ReadWriter, rwc io.ReadWriteCloser, request *http.Request) (conn *Conn) { + return newHixieServerConn(c.Config, buf, rwc, request) +} + +// A hixie75ServerHandshaker performs a server handshake using +// hixie draft 75 protocol. +type hixie75ServerHandshaker struct { + *Config +} + +func (c *hixie75ServerHandshaker) ReadHandshake(buf *bufio.Reader, req *http.Request) (code int, err os.Error) { + c.Version = ProtocolVersionHixie75 + if req.Method != "GET" || req.Proto != "HTTP/1.1" { + return http.StatusMethodNotAllowed, ErrBadRequestMethod + } + if req.Header.Get("Upgrade") != "WebSocket" { + return http.StatusBadRequest, ErrNotWebSocket + } + if req.Header.Get("Connection") != "Upgrade" { + return http.StatusBadRequest, ErrNotWebSocket + } + c.Origin, err = url.ParseRequest(strings.TrimSpace(req.Header.Get("Origin"))) + if err != nil { + return http.StatusBadRequest, err + } + + var scheme string + if req.TLS != nil { + scheme = "wss" + } else { + scheme = "ws" + } + c.Location, err = url.ParseRequest(scheme + "://" + req.Host + req.URL.RawPath) + if err != nil { + return http.StatusBadRequest, err + } + protocol := strings.TrimSpace(req.Header.Get("Websocket-Protocol")) + protocols := strings.Split(protocol, ",") + for i := 0; i < len(protocols); i++ { + c.Protocol = append(c.Protocol, strings.TrimSpace(protocols[i])) + } + + return http.StatusSwitchingProtocols, nil +} + +func (c *hixie75ServerHandshaker) AcceptHandshake(buf *bufio.Writer) (err os.Error) { + if len(c.Protocol) > 0 { + if len(c.Protocol) != 1 { + return ErrBadWebSocketProtocol + } + } + + buf.WriteString("HTTP/1.1 101 Web Socket Protocol Handshake\r\n") + buf.WriteString("Upgrade: WebSocket\r\n") + buf.WriteString("Connection: Upgrade\r\n") + buf.WriteString("WebSocket-Origin: " + c.Origin.String() + "\r\n") + buf.WriteString("WebSocket-Location: " + c.Location.String() + "\r\n") + if len(c.Protocol) > 0 { + buf.WriteString("WebSocket-Protocol: " + c.Protocol[0] + "\r\n") + } + buf.WriteString("\r\n") + return buf.Flush() +} + +func (c *hixie75ServerHandshaker) NewServerConn(buf *bufio.ReadWriter, rwc io.ReadWriteCloser, request *http.Request) (conn *Conn) { + return newHixieServerConn(c.Config, buf, rwc, request) +} + +// newHixieServerConn returns a new WebSocket connection speaking hixie draft protocol. +func newHixieServerConn(config *Config, buf *bufio.ReadWriter, rwc io.ReadWriteCloser, request *http.Request) *Conn { + return newHixieConn(config, buf, rwc, request) +} diff --git a/libgo/go/websocket/hixie_test.go b/libgo/go/websocket/hixie_test.go new file mode 100644 index 00000000000..98a0de4d6f4 --- /dev/null +++ b/libgo/go/websocket/hixie_test.go @@ -0,0 +1,201 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package websocket + +import ( + "bufio" + "bytes" + "fmt" + "http" + "os" + "strings" + "testing" + "url" +) + +// Test the getChallengeResponse function with values from section +// 5.1 of the specification steps 18, 26, and 43 from +// http://tools.ietf.org/html/draft-ietf-hybi-thewebsocketprotocol-00 +func TestHixie76Challenge(t *testing.T) { + var part1 uint32 = 777007543 + var part2 uint32 = 114997259 + key3 := []byte{0x47, 0x30, 0x22, 0x2D, 0x5A, 0x3F, 0x47, 0x58} + expected := []byte("0st3Rl&q-2ZU^weu") + + response, err := getChallengeResponse(part1, part2, key3) + if err != nil { + t.Errorf("getChallengeResponse: returned error %v", err) + return + } + if !bytes.Equal(expected, response) { + t.Errorf("getChallengeResponse: expected %q got %q", expected, response) + } +} + +func TestHixie76ClientHandshake(t *testing.T) { + b := bytes.NewBuffer([]byte{}) + bw := bufio.NewWriter(b) + br := bufio.NewReader(strings.NewReader(`HTTP/1.1 101 WebSocket Protocol Handshake +Upgrade: WebSocket +Connection: Upgrade +Sec-WebSocket-Origin: http://example.com +Sec-WebSocket-Location: ws://example.com/demo +Sec-WebSocket-Protocol: sample + +8jKS'y:G*Co,Wxa-`)) + + var err os.Error + config := new(Config) + config.Location, err = url.ParseRequest("ws://example.com/demo") + if err != nil { + t.Fatal("location url", err) + } + config.Origin, err = url.ParseRequest("http://example.com") + if err != nil { + t.Fatal("origin url", err) + } + config.Protocol = append(config.Protocol, "sample") + config.Version = ProtocolVersionHixie76 + + config.handshakeData = map[string]string{ + "key1": "4 @1 46546xW%0l 1 5", + "number1": "829309203", + "key2": "12998 5 Y3 1 .P00", + "number2": "259970620", + "key3": "^n:ds[4U", + } + err = hixie76ClientHandshake(config, br, bw) + if err != nil { + t.Errorf("handshake failed: %v", err) + } + req, err := http.ReadRequest(bufio.NewReader(b)) + if err != nil { + t.Fatalf("read request: %v", err) + } + if req.Method != "GET" { + t.Errorf("request method expected GET, but got %q", req.Method) + } + if req.URL.Path != "/demo" { + t.Errorf("request path expected /demo, but got %q", req.URL.Path) + } + if req.Proto != "HTTP/1.1" { + t.Errorf("request proto expected HTTP/1.1, but got %q", req.Proto) + } + if req.Host != "example.com" { + t.Errorf("request Host expected example.com, but got %v", req.Host) + } + var expectedHeader = map[string]string{ + "Connection": "Upgrade", + "Upgrade": "WebSocket", + "Origin": "http://example.com", + "Sec-Websocket-Key1": config.handshakeData["key1"], + "Sec-Websocket-Key2": config.handshakeData["key2"], + "Sec-WebSocket-Protocol": config.Protocol[0], + } + for k, v := range expectedHeader { + if req.Header.Get(k) != v { + t.Errorf(fmt.Sprintf("%s expected %q but got %q", k, v, req.Header.Get(k))) + } + } +} + +func TestHixie76ServerHandshake(t *testing.T) { + config := new(Config) + handshaker := &hixie76ServerHandshaker{Config: config} + br := bufio.NewReader(strings.NewReader(`GET /demo HTTP/1.1 +Host: example.com +Connection: Upgrade +Sec-WebSocket-Key2: 12998 5 Y3 1 .P00 +Sec-WebSocket-Protocol: sample +Upgrade: WebSocket +Sec-WebSocket-Key1: 4 @1 46546xW%0l 1 5 +Origin: http://example.com + +^n:ds[4U`)) + req, err := http.ReadRequest(br) + if err != nil { + t.Fatal("request", err) + } + code, err := handshaker.ReadHandshake(br, req) + if err != nil { + t.Errorf("handshake failed: %v", err) + } + if code != http.StatusSwitchingProtocols { + t.Errorf("status expected %q but got %q", http.StatusSwitchingProtocols, code) + } + b := bytes.NewBuffer([]byte{}) + bw := bufio.NewWriter(b) + + err = handshaker.AcceptHandshake(bw) + if err != nil { + t.Errorf("handshake response failed: %v", err) + } + expectedResponse := strings.Join([]string{ + "HTTP/1.1 101 WebSocket Protocol Handshake", + "Upgrade: WebSocket", + "Connection: Upgrade", + "Sec-WebSocket-Origin: http://example.com", + "Sec-WebSocket-Location: ws://example.com/demo", + "Sec-WebSocket-Protocol: sample", + "", ""}, "\r\n") + "8jKS'y:G*Co,Wxa-" + if b.String() != expectedResponse { + t.Errorf("handshake expected %q but got %q", expectedResponse, b.String()) + } +} + +func TestHixie76SkipLengthFrame(t *testing.T) { + b := []byte{'\x80', '\x01', 'x', 0, 'h', 'e', 'l', 'l', 'o', '\xff'} + buf := bytes.NewBuffer(b) + br := bufio.NewReader(buf) + bw := bufio.NewWriter(buf) + config := newConfig(t, "/") + ws := newHixieConn(config, bufio.NewReadWriter(br, bw), nil, nil) + msg := make([]byte, 5) + n, err := ws.Read(msg) + if err != nil { + t.Errorf("Read: %v", err) + } + if !bytes.Equal(b[4:9], msg[0:n]) { + t.Errorf("Read: expected %q got %q", b[4:9], msg[0:n]) + } +} + +func TestHixie76SkipNoUTF8Frame(t *testing.T) { + b := []byte{'\x01', 'n', '\xff', 0, 'h', 'e', 'l', 'l', 'o', '\xff'} + buf := bytes.NewBuffer(b) + br := bufio.NewReader(buf) + bw := bufio.NewWriter(buf) + config := newConfig(t, "/") + ws := newHixieConn(config, bufio.NewReadWriter(br, bw), nil, nil) + msg := make([]byte, 5) + n, err := ws.Read(msg) + if err != nil { + t.Errorf("Read: %v", err) + } + if !bytes.Equal(b[4:9], msg[0:n]) { + t.Errorf("Read: expected %q got %q", b[4:9], msg[0:n]) + } +} + +func TestHixie76ClosingFrame(t *testing.T) { + b := []byte{0, 'h', 'e', 'l', 'l', 'o', '\xff'} + buf := bytes.NewBuffer(b) + br := bufio.NewReader(buf) + bw := bufio.NewWriter(buf) + config := newConfig(t, "/") + ws := newHixieConn(config, bufio.NewReadWriter(br, bw), nil, nil) + msg := make([]byte, 5) + n, err := ws.Read(msg) + if err != nil { + t.Errorf("read: %v", err) + } + if !bytes.Equal(b[1:6], msg[0:n]) { + t.Errorf("Read: expected %q got %q", b[1:6], msg[0:n]) + } + n, err = ws.Read(msg) + if err != os.EOF { + t.Errorf("read: %v", err) + } +} diff --git a/libgo/go/websocket/hybi.go b/libgo/go/websocket/hybi.go new file mode 100644 index 00000000000..fe08b3d738b --- /dev/null +++ b/libgo/go/websocket/hybi.go @@ -0,0 +1,550 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package websocket + +// This file implements a protocol of hybi draft. +// http://tools.ietf.org/html/draft-ietf-hybi-thewebsocketprotocol-17 + +import ( + "bufio" + "bytes" + "crypto/rand" + "crypto/sha1" + "encoding/base64" + "encoding/binary" + "fmt" + "http" + "io" + "io/ioutil" + "os" + "strings" + "url" +) + +const ( + websocketGUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" + + closeStatusNormal = 1000 + closeStatusGoingAway = 1001 + closeStatusProtocolError = 1002 + closeStatusUnsupportedData = 1003 + closeStatusFrameTooLarge = 1004 + closeStatusNoStatusRcvd = 1005 + closeStatusAbnormalClosure = 1006 + closeStatusBadMessageData = 1007 + closeStatusPolicyViolation = 1008 + closeStatusTooBigData = 1009 + closeStatusExtensionMismatch = 1010 + + maxControlFramePayloadLength = 125 +) + +var ( + ErrBadMaskingKey = &ProtocolError{"bad masking key"} + ErrBadPongMessage = &ProtocolError{"bad pong message"} + ErrBadClosingStatus = &ProtocolError{"bad closing status"} + ErrUnsupportedExtensions = &ProtocolError{"unsupported extensions"} + ErrNotImplemented = &ProtocolError{"not implemented"} +) + +// A hybiFrameHeader is a frame header as defined in hybi draft. +type hybiFrameHeader struct { + Fin bool + Rsv [3]bool + OpCode byte + Length int64 + MaskingKey []byte + + data *bytes.Buffer +} + +// A hybiFrameReader is a reader for hybi frame. +type hybiFrameReader struct { + reader io.Reader + + header hybiFrameHeader + pos int64 + length int +} + +func (frame *hybiFrameReader) Read(msg []byte) (n int, err os.Error) { + n, err = frame.reader.Read(msg) + if err != nil { + return 0, err + } + if frame.header.MaskingKey != nil { + for i := 0; i < n; i++ { + msg[i] = msg[i] ^ frame.header.MaskingKey[frame.pos%4] + frame.pos++ + } + } + return n, err +} + +func (frame *hybiFrameReader) PayloadType() byte { return frame.header.OpCode } + +func (frame *hybiFrameReader) HeaderReader() io.Reader { + if frame.header.data == nil { + return nil + } + if frame.header.data.Len() == 0 { + return nil + } + return frame.header.data +} + +func (frame *hybiFrameReader) TrailerReader() io.Reader { return nil } + +func (frame *hybiFrameReader) Len() (n int) { return frame.length } + +// A hybiFrameReaderFactory creates new frame reader based on its frame type. +type hybiFrameReaderFactory struct { + *bufio.Reader +} + +// NewFrameReader reads a frame header from the connection, and creates new reader for the frame. +// See Section 5.2 Base Frameing protocol for detail. +// http://tools.ietf.org/html/draft-ietf-hybi-thewebsocketprotocol-17#section-5.2 +func (buf hybiFrameReaderFactory) NewFrameReader() (frame frameReader, err os.Error) { + hybiFrame := new(hybiFrameReader) + frame = hybiFrame + var header []byte + var b byte + // First byte. FIN/RSV1/RSV2/RSV3/OpCode(4bits) + b, err = buf.ReadByte() + if err != nil { + return + } + header = append(header, b) + hybiFrame.header.Fin = ((header[0] >> 7) & 1) != 0 + for i := 0; i < 3; i++ { + j := uint(6 - i) + hybiFrame.header.Rsv[i] = ((header[0] >> j) & 1) != 0 + } + hybiFrame.header.OpCode = header[0] & 0x0f + + // Second byte. Mask/Payload len(7bits) + b, err = buf.ReadByte() + if err != nil { + return + } + header = append(header, b) + mask := (b & 0x80) != 0 + b &= 0x7f + lengthFields := 0 + switch { + case b <= 125: // Payload length 7bits. + hybiFrame.header.Length = int64(b) + case b == 126: // Payload length 7+16bits + lengthFields = 2 + case b == 127: // Payload length 7+64bits + lengthFields = 8 + } + for i := 0; i < lengthFields; i++ { + b, err = buf.ReadByte() + if err != nil { + return + } + header = append(header, b) + hybiFrame.header.Length = hybiFrame.header.Length*256 + int64(b) + } + if mask { + // Masking key. 4 bytes. + for i := 0; i < 4; i++ { + b, err = buf.ReadByte() + if err != nil { + return + } + header = append(header, b) + hybiFrame.header.MaskingKey = append(hybiFrame.header.MaskingKey, b) + } + } + hybiFrame.reader = io.LimitReader(buf.Reader, hybiFrame.header.Length) + hybiFrame.header.data = bytes.NewBuffer(header) + hybiFrame.length = len(header) + int(hybiFrame.header.Length) + return +} + +// A HybiFrameWriter is a writer for hybi frame. +type hybiFrameWriter struct { + writer *bufio.Writer + + header *hybiFrameHeader +} + +func (frame *hybiFrameWriter) Write(msg []byte) (n int, err os.Error) { + var header []byte + var b byte + if frame.header.Fin { + b |= 0x80 + } + for i := 0; i < 3; i++ { + if frame.header.Rsv[i] { + j := uint(6 - i) + b |= 1 << j + } + } + b |= frame.header.OpCode + header = append(header, b) + if frame.header.MaskingKey != nil { + b = 0x80 + } else { + b = 0 + } + lengthFields := 0 + length := len(msg) + switch { + case length <= 125: + b |= byte(length) + case length < 65536: + b |= 126 + lengthFields = 2 + default: + b |= 127 + lengthFields = 8 + } + header = append(header, b) + for i := 0; i < lengthFields; i++ { + j := uint((lengthFields - i - 1) * 8) + b = byte((length >> j) & 0xff) + header = append(header, b) + } + if frame.header.MaskingKey != nil { + if len(frame.header.MaskingKey) != 4 { + return 0, ErrBadMaskingKey + } + header = append(header, frame.header.MaskingKey...) + frame.writer.Write(header) + var data []byte + + for i := 0; i < length; i++ { + data = append(data, msg[i]^frame.header.MaskingKey[i%4]) + } + frame.writer.Write(data) + err = frame.writer.Flush() + return length, err + } + frame.writer.Write(header) + frame.writer.Write(msg) + err = frame.writer.Flush() + return length, err +} + +func (frame *hybiFrameWriter) Close() os.Error { return nil } + +type hybiFrameWriterFactory struct { + *bufio.Writer + needMaskingKey bool +} + +func (buf hybiFrameWriterFactory) NewFrameWriter(payloadType byte) (frame frameWriter, err os.Error) { + frameHeader := &hybiFrameHeader{Fin: true, OpCode: payloadType} + if buf.needMaskingKey { + frameHeader.MaskingKey, err = generateMaskingKey() + if err != nil { + return nil, err + } + } + return &hybiFrameWriter{writer: buf.Writer, header: frameHeader}, nil +} + +type hybiFrameHandler struct { + conn *Conn + payloadType byte +} + +func (handler *hybiFrameHandler) HandleFrame(frame frameReader) (r frameReader, err os.Error) { + if handler.conn.IsServerConn() { + // The client MUST mask all frames sent to the server. + if frame.(*hybiFrameReader).header.MaskingKey == nil { + handler.WriteClose(closeStatusProtocolError) + return nil, os.EOF + } + } else { + // The server MUST NOT mask all frames. + if frame.(*hybiFrameReader).header.MaskingKey != nil { + handler.WriteClose(closeStatusProtocolError) + return nil, os.EOF + } + } + if header := frame.HeaderReader(); header != nil { + io.Copy(ioutil.Discard, header) + } + switch frame.PayloadType() { + case ContinuationFrame: + frame.(*hybiFrameReader).header.OpCode = handler.payloadType + case TextFrame, BinaryFrame: + handler.payloadType = frame.PayloadType() + case CloseFrame: + return nil, os.EOF + case PingFrame: + pingMsg := make([]byte, maxControlFramePayloadLength) + n, err := io.ReadFull(frame, pingMsg) + if err != nil && err != io.ErrUnexpectedEOF { + return nil, err + } + io.Copy(ioutil.Discard, frame) + n, err = handler.WritePong(pingMsg[:n]) + if err != nil { + return nil, err + } + return nil, nil + case PongFrame: + return nil, ErrNotImplemented + } + return frame, nil +} + +func (handler *hybiFrameHandler) WriteClose(status int) (err os.Error) { + handler.conn.wio.Lock() + defer handler.conn.wio.Unlock() + w, err := handler.conn.frameWriterFactory.NewFrameWriter(CloseFrame) + if err != nil { + return err + } + msg := make([]byte, 2) + binary.BigEndian.PutUint16(msg, uint16(status)) + _, err = w.Write(msg) + w.Close() + return err +} + +func (handler *hybiFrameHandler) WritePong(msg []byte) (n int, err os.Error) { + handler.conn.wio.Lock() + defer handler.conn.wio.Unlock() + w, err := handler.conn.frameWriterFactory.NewFrameWriter(PongFrame) + if err != nil { + return 0, err + } + n, err = w.Write(msg) + w.Close() + return n, err +} + +// newHybiConn creates a new WebSocket connection speaking hybi draft protocol. +func newHybiConn(config *Config, buf *bufio.ReadWriter, rwc io.ReadWriteCloser, request *http.Request) *Conn { + if buf == nil { + br := bufio.NewReader(rwc) + bw := bufio.NewWriter(rwc) + buf = bufio.NewReadWriter(br, bw) + } + ws := &Conn{config: config, request: request, buf: buf, rwc: rwc, + frameReaderFactory: hybiFrameReaderFactory{buf.Reader}, + frameWriterFactory: hybiFrameWriterFactory{ + buf.Writer, request == nil}, + PayloadType: TextFrame, + defaultCloseStatus: closeStatusNormal} + ws.frameHandler = &hybiFrameHandler{conn: ws} + return ws +} + +// generateMaskingKey generates a masking key for a frame. +func generateMaskingKey() (maskingKey []byte, err os.Error) { + maskingKey = make([]byte, 4) + if _, err = io.ReadFull(rand.Reader, maskingKey); err != nil { + return + } + return +} + +// genetateNonce geneates a nonce consisting of a randomly selected 16-byte +// value that has been base64-encoded. +func generateNonce() (nonce []byte) { + key := make([]byte, 16) + if _, err := io.ReadFull(rand.Reader, key); err != nil { + panic(err) + } + nonce = make([]byte, 24) + base64.StdEncoding.Encode(nonce, key) + return +} + +// getNonceAccept computes the base64-encoded SHA-1 of the concatenation of +// the nonce ("Sec-WebSocket-Key" value) with the websocket GUID string. +func getNonceAccept(nonce []byte) (expected []byte, err os.Error) { + h := sha1.New() + if _, err = h.Write(nonce); err != nil { + return + } + if _, err = h.Write([]byte(websocketGUID)); err != nil { + return + } + expected = make([]byte, 28) + base64.StdEncoding.Encode(expected, h.Sum()) + return +} + +func isHybiVersion(version int) bool { + switch version { + case ProtocolVersionHybi08, ProtocolVersionHybi13: + return true + default: + } + return false +} + +// Client handhake described in draft-ietf-hybi-thewebsocket-protocol-17 +func hybiClientHandshake(config *Config, br *bufio.Reader, bw *bufio.Writer) (err os.Error) { + if !isHybiVersion(config.Version) { + panic("wrong protocol version.") + } + + bw.WriteString("GET " + config.Location.RawPath + " HTTP/1.1\r\n") + + bw.WriteString("Host: " + config.Location.Host + "\r\n") + bw.WriteString("Upgrade: websocket\r\n") + bw.WriteString("Connection: Upgrade\r\n") + nonce := generateNonce() + if config.handshakeData != nil { + nonce = []byte(config.handshakeData["key"]) + } + bw.WriteString("Sec-WebSocket-Key: " + string(nonce) + "\r\n") + if config.Version == ProtocolVersionHybi13 { + bw.WriteString("Origin: " + strings.ToLower(config.Origin.String()) + "\r\n") + } else if config.Version == ProtocolVersionHybi08 { + bw.WriteString("Sec-WebSocket-Origin: " + strings.ToLower(config.Origin.String()) + "\r\n") + } + bw.WriteString("Sec-WebSocket-Version: " + fmt.Sprintf("%d", config.Version) + "\r\n") + if len(config.Protocol) > 0 { + bw.WriteString("Sec-WebSocket-Protocol: " + strings.Join(config.Protocol, ", ") + "\r\n") + } + // TODO(ukai): send extensions. + // TODO(ukai): send cookie if any. + + bw.WriteString("\r\n") + if err = bw.Flush(); err != nil { + return err + } + + resp, err := http.ReadResponse(br, &http.Request{Method: "GET"}) + if err != nil { + return err + } + if resp.StatusCode != 101 { + return ErrBadStatus + } + if strings.ToLower(resp.Header.Get("Upgrade")) != "websocket" || + strings.ToLower(resp.Header.Get("Connection")) != "upgrade" { + return ErrBadUpgrade + } + expectedAccept, err := getNonceAccept(nonce) + if err != nil { + return err + } + if resp.Header.Get("Sec-WebSocket-Accept") != string(expectedAccept) { + return ErrChallengeResponse + } + if resp.Header.Get("Sec-WebSocket-Extensions") != "" { + return ErrUnsupportedExtensions + } + offeredProtocol := resp.Header.Get("Sec-WebSocket-Protocol") + if offeredProtocol != "" { + protocolMatched := false + for i := 0; i < len(config.Protocol); i++ { + if config.Protocol[i] == offeredProtocol { + protocolMatched = true + break + } + } + if !protocolMatched { + return ErrBadWebSocketProtocol + } + config.Protocol = []string{offeredProtocol} + } + + return nil +} + +// newHybiClientConn creates a client WebSocket connection after handshake. +func newHybiClientConn(config *Config, buf *bufio.ReadWriter, rwc io.ReadWriteCloser) *Conn { + return newHybiConn(config, buf, rwc, nil) +} + +// A HybiServerHandshaker performs a server handshake using hybi draft protocol. +type hybiServerHandshaker struct { + *Config + accept []byte +} + +func (c *hybiServerHandshaker) ReadHandshake(buf *bufio.Reader, req *http.Request) (code int, err os.Error) { + c.Version = ProtocolVersionHybi13 + if req.Method != "GET" { + return http.StatusMethodNotAllowed, ErrBadRequestMethod + } + // HTTP version can be safely ignored. + + if strings.ToLower(req.Header.Get("Upgrade")) != "websocket" || + !strings.Contains(strings.ToLower(req.Header.Get("Connection")), "upgrade") { + return http.StatusBadRequest, ErrNotWebSocket + } + + key := req.Header.Get("Sec-Websocket-Key") + if key == "" { + return http.StatusBadRequest, ErrChallengeResponse + } + version := req.Header.Get("Sec-Websocket-Version") + var origin string + switch version { + case "13": + c.Version = ProtocolVersionHybi13 + origin = req.Header.Get("Origin") + case "8": + c.Version = ProtocolVersionHybi08 + origin = req.Header.Get("Sec-Websocket-Origin") + default: + return http.StatusBadRequest, ErrBadWebSocketVersion + } + c.Origin, err = url.ParseRequest(origin) + if err != nil { + return http.StatusForbidden, err + } + var scheme string + if req.TLS != nil { + scheme = "wss" + } else { + scheme = "ws" + } + c.Location, err = url.ParseRequest(scheme + "://" + req.Host + req.URL.RawPath) + if err != nil { + return http.StatusBadRequest, err + } + protocol := strings.TrimSpace(req.Header.Get("Sec-Websocket-Protocol")) + protocols := strings.Split(protocol, ",") + for i := 0; i < len(protocols); i++ { + c.Protocol = append(c.Protocol, strings.TrimSpace(protocols[i])) + } + c.accept, err = getNonceAccept([]byte(key)) + if err != nil { + return http.StatusInternalServerError, err + } + return http.StatusSwitchingProtocols, nil +} + +func (c *hybiServerHandshaker) AcceptHandshake(buf *bufio.Writer) (err os.Error) { + if len(c.Protocol) > 0 { + if len(c.Protocol) != 1 { + return ErrBadWebSocketProtocol + } + } + buf.WriteString("HTTP/1.1 101 Switching Protocols\r\n") + buf.WriteString("Upgrade: websocket\r\n") + buf.WriteString("Connection: Upgrade\r\n") + buf.WriteString("Sec-WebSocket-Accept: " + string(c.accept) + "\r\n") + if len(c.Protocol) > 0 { + buf.WriteString("Sec-WebSocket-Protocol: " + c.Protocol[0] + "\r\n") + } + // TODO(ukai): support extensions + buf.WriteString("\r\n") + return buf.Flush() +} + +func (c *hybiServerHandshaker) NewServerConn(buf *bufio.ReadWriter, rwc io.ReadWriteCloser, request *http.Request) *Conn { + return newHybiServerConn(c.Config, buf, rwc, request) +} + +// newHybiServerConn returns a new WebSocket connection speaking hybi draft protocol. +func newHybiServerConn(config *Config, buf *bufio.ReadWriter, rwc io.ReadWriteCloser, request *http.Request) *Conn { + return newHybiConn(config, buf, rwc, request) +} diff --git a/libgo/go/websocket/hybi_test.go b/libgo/go/websocket/hybi_test.go new file mode 100644 index 00000000000..9db57e3f1b7 --- /dev/null +++ b/libgo/go/websocket/hybi_test.go @@ -0,0 +1,584 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package websocket + +import ( + "bufio" + "bytes" + "fmt" + "http" + "os" + "strings" + "testing" + "url" +) + +// Test the getNonceAccept function with values in +// http://tools.ietf.org/html/draft-ietf-hybi-thewebsocketprotocol-17 +func TestSecWebSocketAccept(t *testing.T) { + nonce := []byte("dGhlIHNhbXBsZSBub25jZQ==") + expected := []byte("s3pPLMBiTxaQ9kYGzzhZRbK+xOo=") + accept, err := getNonceAccept(nonce) + if err != nil { + t.Errorf("getNonceAccept: returned error %v", err) + return + } + if !bytes.Equal(expected, accept) { + t.Errorf("getNonceAccept: expected %q got %q", expected, accept) + } +} + +func TestHybiClientHandshake(t *testing.T) { + b := bytes.NewBuffer([]byte{}) + bw := bufio.NewWriter(b) + br := bufio.NewReader(strings.NewReader(`HTTP/1.1 101 Switching Protocols +Upgrade: websocket +Connection: Upgrade +Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo= +Sec-WebSocket-Protocol: chat + +`)) + var err os.Error + config := new(Config) + config.Location, err = url.ParseRequest("ws://server.example.com/chat") + if err != nil { + t.Fatal("location url", err) + } + config.Origin, err = url.ParseRequest("http://example.com") + if err != nil { + t.Fatal("origin url", err) + } + config.Protocol = append(config.Protocol, "chat") + config.Protocol = append(config.Protocol, "superchat") + config.Version = ProtocolVersionHybi13 + + config.handshakeData = map[string]string{ + "key": "dGhlIHNhbXBsZSBub25jZQ==", + } + err = hybiClientHandshake(config, br, bw) + if err != nil { + t.Errorf("handshake failed: %v", err) + } + req, err := http.ReadRequest(bufio.NewReader(b)) + if err != nil { + t.Fatalf("read request: %v", err) + } + if req.Method != "GET" { + t.Errorf("request method expected GET, but got %q", req.Method) + } + if req.URL.Path != "/chat" { + t.Errorf("request path expected /chat, but got %q", req.URL.Path) + } + if req.Proto != "HTTP/1.1" { + t.Errorf("request proto expected HTTP/1.1, but got %q", req.Proto) + } + if req.Host != "server.example.com" { + t.Errorf("request Host expected server.example.com, but got %v", req.Host) + } + var expectedHeader = map[string]string{ + "Connection": "Upgrade", + "Upgrade": "websocket", + "Sec-Websocket-Key": config.handshakeData["key"], + "Origin": config.Origin.String(), + "Sec-Websocket-Protocol": "chat, superchat", + "Sec-Websocket-Version": fmt.Sprintf("%d", ProtocolVersionHybi13), + } + for k, v := range expectedHeader { + if req.Header.Get(k) != v { + t.Errorf(fmt.Sprintf("%s expected %q but got %q", k, v, req.Header.Get(k))) + } + } +} + +func TestHybiClientHandshakeHybi08(t *testing.T) { + b := bytes.NewBuffer([]byte{}) + bw := bufio.NewWriter(b) + br := bufio.NewReader(strings.NewReader(`HTTP/1.1 101 Switching Protocols +Upgrade: websocket +Connection: Upgrade +Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo= +Sec-WebSocket-Protocol: chat + +`)) + var err os.Error + config := new(Config) + config.Location, err = url.ParseRequest("ws://server.example.com/chat") + if err != nil { + t.Fatal("location url", err) + } + config.Origin, err = url.ParseRequest("http://example.com") + if err != nil { + t.Fatal("origin url", err) + } + config.Protocol = append(config.Protocol, "chat") + config.Protocol = append(config.Protocol, "superchat") + config.Version = ProtocolVersionHybi08 + + config.handshakeData = map[string]string{ + "key": "dGhlIHNhbXBsZSBub25jZQ==", + } + err = hybiClientHandshake(config, br, bw) + if err != nil { + t.Errorf("handshake failed: %v", err) + } + req, err := http.ReadRequest(bufio.NewReader(b)) + if err != nil { + t.Fatalf("read request: %v", err) + } + if req.Method != "GET" { + t.Errorf("request method expected GET, but got %q", req.Method) + } + if req.URL.Path != "/chat" { + t.Errorf("request path expected /demo, but got %q", req.URL.Path) + } + if req.Proto != "HTTP/1.1" { + t.Errorf("request proto expected HTTP/1.1, but got %q", req.Proto) + } + if req.Host != "server.example.com" { + t.Errorf("request Host expected example.com, but got %v", req.Host) + } + var expectedHeader = map[string]string{ + "Connection": "Upgrade", + "Upgrade": "websocket", + "Sec-Websocket-Key": config.handshakeData["key"], + "Sec-Websocket-Origin": config.Origin.String(), + "Sec-Websocket-Protocol": "chat, superchat", + "Sec-Websocket-Version": fmt.Sprintf("%d", ProtocolVersionHybi08), + } + for k, v := range expectedHeader { + if req.Header.Get(k) != v { + t.Errorf(fmt.Sprintf("%s expected %q but got %q", k, v, req.Header.Get(k))) + } + } +} + +func TestHybiServerHandshake(t *testing.T) { + config := new(Config) + handshaker := &hybiServerHandshaker{Config: config} + br := bufio.NewReader(strings.NewReader(`GET /chat HTTP/1.1 +Host: server.example.com +Upgrade: websocket +Connection: Upgrade +Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ== +Origin: http://example.com +Sec-WebSocket-Protocol: chat, superchat +Sec-WebSocket-Version: 13 + +`)) + req, err := http.ReadRequest(br) + if err != nil { + t.Fatal("request", err) + } + code, err := handshaker.ReadHandshake(br, req) + if err != nil { + t.Errorf("handshake failed: %v", err) + } + if code != http.StatusSwitchingProtocols { + t.Errorf("status expected %q but got %q", http.StatusSwitchingProtocols, code) + } + b := bytes.NewBuffer([]byte{}) + bw := bufio.NewWriter(b) + + config.Protocol = []string{"chat"} + + err = handshaker.AcceptHandshake(bw) + if err != nil { + t.Errorf("handshake response failed: %v", err) + } + expectedResponse := strings.Join([]string{ + "HTTP/1.1 101 Switching Protocols", + "Upgrade: websocket", + "Connection: Upgrade", + "Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=", + "Sec-WebSocket-Protocol: chat", + "", ""}, "\r\n") + + if b.String() != expectedResponse { + t.Errorf("handshake expected %q but got %q", expectedResponse, b.String()) + } +} + +func TestHybiServerHandshakeHybi08(t *testing.T) { + config := new(Config) + handshaker := &hybiServerHandshaker{Config: config} + br := bufio.NewReader(strings.NewReader(`GET /chat HTTP/1.1 +Host: server.example.com +Upgrade: websocket +Connection: Upgrade +Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ== +Sec-WebSocket-Origin: http://example.com +Sec-WebSocket-Protocol: chat, superchat +Sec-WebSocket-Version: 8 + +`)) + req, err := http.ReadRequest(br) + if err != nil { + t.Fatal("request", err) + } + code, err := handshaker.ReadHandshake(br, req) + if err != nil { + t.Errorf("handshake failed: %v", err) + } + if code != http.StatusSwitchingProtocols { + t.Errorf("status expected %q but got %q", http.StatusSwitchingProtocols, code) + } + b := bytes.NewBuffer([]byte{}) + bw := bufio.NewWriter(b) + + config.Protocol = []string{"chat"} + + err = handshaker.AcceptHandshake(bw) + if err != nil { + t.Errorf("handshake response failed: %v", err) + } + expectedResponse := strings.Join([]string{ + "HTTP/1.1 101 Switching Protocols", + "Upgrade: websocket", + "Connection: Upgrade", + "Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=", + "Sec-WebSocket-Protocol: chat", + "", ""}, "\r\n") + + if b.String() != expectedResponse { + t.Errorf("handshake expected %q but got %q", expectedResponse, b.String()) + } +} + +func TestHybiServerHandshakeHybiBadVersion(t *testing.T) { + config := new(Config) + handshaker := &hybiServerHandshaker{Config: config} + br := bufio.NewReader(strings.NewReader(`GET /chat HTTP/1.1 +Host: server.example.com +Upgrade: websocket +Connection: Upgrade +Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ== +Sec-WebSocket-Origin: http://example.com +Sec-WebSocket-Protocol: chat, superchat +Sec-WebSocket-Version: 9 + +`)) + req, err := http.ReadRequest(br) + if err != nil { + t.Fatal("request", err) + } + code, err := handshaker.ReadHandshake(br, req) + if err != ErrBadWebSocketVersion { + t.Errorf("handshake expected err %q but got %q", ErrBadWebSocketVersion, err) + } + if code != http.StatusBadRequest { + t.Errorf("status expected %q but got %q", http.StatusBadRequest, code) + } +} + +func testHybiFrame(t *testing.T, testHeader, testPayload, testMaskedPayload []byte, frameHeader *hybiFrameHeader) { + b := bytes.NewBuffer([]byte{}) + frameWriterFactory := &hybiFrameWriterFactory{bufio.NewWriter(b), false} + w, _ := frameWriterFactory.NewFrameWriter(TextFrame) + w.(*hybiFrameWriter).header = frameHeader + _, err := w.Write(testPayload) + w.Close() + if err != nil { + t.Errorf("Write error %q", err) + } + var expectedFrame []byte + expectedFrame = append(expectedFrame, testHeader...) + expectedFrame = append(expectedFrame, testMaskedPayload...) + if !bytes.Equal(expectedFrame, b.Bytes()) { + t.Errorf("frame expected %q got %q", expectedFrame, b.Bytes()) + } + frameReaderFactory := &hybiFrameReaderFactory{bufio.NewReader(b)} + r, err := frameReaderFactory.NewFrameReader() + if err != nil { + t.Errorf("Read error %q", err) + } + if header := r.HeaderReader(); header == nil { + t.Errorf("no header") + } else { + actualHeader := make([]byte, r.Len()) + n, err := header.Read(actualHeader) + if err != nil { + t.Errorf("Read header error %q", err) + } else { + if n < len(testHeader) { + t.Errorf("header too short %q got %q", testHeader, actualHeader[:n]) + } + if !bytes.Equal(testHeader, actualHeader[:n]) { + t.Errorf("header expected %q got %q", testHeader, actualHeader[:n]) + } + } + } + if trailer := r.TrailerReader(); trailer != nil { + t.Errorf("unexpected trailer %q", trailer) + } + frame := r.(*hybiFrameReader) + if frameHeader.Fin != frame.header.Fin || + frameHeader.OpCode != frame.header.OpCode || + len(testPayload) != int(frame.header.Length) { + t.Errorf("mismatch %v (%d) vs %v", frameHeader, len(testPayload), frame) + } + payload := make([]byte, len(testPayload)) + _, err = r.Read(payload) + if err != nil { + t.Errorf("read %v", err) + } + if !bytes.Equal(testPayload, payload) { + t.Errorf("payload %q vs %q", testPayload, payload) + } +} + +func TestHybiShortTextFrame(t *testing.T) { + frameHeader := &hybiFrameHeader{Fin: true, OpCode: TextFrame} + payload := []byte("hello") + testHybiFrame(t, []byte{0x81, 0x05}, payload, payload, frameHeader) + + payload = make([]byte, 125) + testHybiFrame(t, []byte{0x81, 125}, payload, payload, frameHeader) +} + +func TestHybiShortMaskedTextFrame(t *testing.T) { + frameHeader := &hybiFrameHeader{Fin: true, OpCode: TextFrame, + MaskingKey: []byte{0xcc, 0x55, 0x80, 0x20}} + payload := []byte("hello") + maskedPayload := []byte{0xa4, 0x30, 0xec, 0x4c, 0xa3} + header := []byte{0x81, 0x85} + header = append(header, frameHeader.MaskingKey...) + testHybiFrame(t, header, payload, maskedPayload, frameHeader) +} + +func TestHybiShortBinaryFrame(t *testing.T) { + frameHeader := &hybiFrameHeader{Fin: true, OpCode: BinaryFrame} + payload := []byte("hello") + testHybiFrame(t, []byte{0x82, 0x05}, payload, payload, frameHeader) + + payload = make([]byte, 125) + testHybiFrame(t, []byte{0x82, 125}, payload, payload, frameHeader) +} + +func TestHybiControlFrame(t *testing.T) { + frameHeader := &hybiFrameHeader{Fin: true, OpCode: PingFrame} + payload := []byte("hello") + testHybiFrame(t, []byte{0x89, 0x05}, payload, payload, frameHeader) + + frameHeader = &hybiFrameHeader{Fin: true, OpCode: PongFrame} + testHybiFrame(t, []byte{0x8A, 0x05}, payload, payload, frameHeader) + + frameHeader = &hybiFrameHeader{Fin: true, OpCode: CloseFrame} + payload = []byte{0x03, 0xe8} // 1000 + testHybiFrame(t, []byte{0x88, 0x02}, payload, payload, frameHeader) +} + +func TestHybiLongFrame(t *testing.T) { + frameHeader := &hybiFrameHeader{Fin: true, OpCode: TextFrame} + payload := make([]byte, 126) + testHybiFrame(t, []byte{0x81, 126, 0x00, 126}, payload, payload, frameHeader) + + payload = make([]byte, 65535) + testHybiFrame(t, []byte{0x81, 126, 0xff, 0xff}, payload, payload, frameHeader) + + payload = make([]byte, 65536) + testHybiFrame(t, []byte{0x81, 127, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00}, payload, payload, frameHeader) +} + +func TestHybiClientRead(t *testing.T) { + wireData := []byte{0x81, 0x05, 'h', 'e', 'l', 'l', 'o', + 0x89, 0x05, 'h', 'e', 'l', 'l', 'o', // ping + 0x81, 0x05, 'w', 'o', 'r', 'l', 'd'} + br := bufio.NewReader(bytes.NewBuffer(wireData)) + bw := bufio.NewWriter(bytes.NewBuffer([]byte{})) + conn := newHybiConn(newConfig(t, "/"), bufio.NewReadWriter(br, bw), nil, nil) + + msg := make([]byte, 512) + n, err := conn.Read(msg) + if err != nil { + t.Errorf("read 1st frame, error %q", err) + } + if n != 5 { + t.Errorf("read 1st frame, expect 5, got %d", n) + } + if !bytes.Equal(wireData[2:7], msg[:n]) { + t.Errorf("read 1st frame %v, got %v", wireData[2:7], msg[:n]) + } + n, err = conn.Read(msg) + if err != nil { + t.Errorf("read 2nd frame, error %q", err) + } + if n != 5 { + t.Errorf("read 2nd frame, expect 5, got %d", n) + } + if !bytes.Equal(wireData[16:21], msg[:n]) { + t.Errorf("read 2nd frame %v, got %v", wireData[16:21], msg[:n]) + } + n, err = conn.Read(msg) + if err == nil { + t.Errorf("read not EOF") + } + if n != 0 { + t.Errorf("expect read 0, got %d", n) + } +} + +func TestHybiShortRead(t *testing.T) { + wireData := []byte{0x81, 0x05, 'h', 'e', 'l', 'l', 'o', + 0x89, 0x05, 'h', 'e', 'l', 'l', 'o', // ping + 0x81, 0x05, 'w', 'o', 'r', 'l', 'd'} + br := bufio.NewReader(bytes.NewBuffer(wireData)) + bw := bufio.NewWriter(bytes.NewBuffer([]byte{})) + conn := newHybiConn(newConfig(t, "/"), bufio.NewReadWriter(br, bw), nil, nil) + + step := 0 + pos := 0 + expectedPos := []int{2, 5, 16, 19} + expectedLen := []int{3, 2, 3, 2} + for { + msg := make([]byte, 3) + n, err := conn.Read(msg) + if step >= len(expectedPos) { + if err == nil { + t.Errorf("read not EOF") + } + if n != 0 { + t.Errorf("expect read 0, got %d", n) + } + return + } + pos = expectedPos[step] + endPos := pos + expectedLen[step] + if err != nil { + t.Errorf("read from %d, got error %q", pos, err) + return + } + if n != endPos-pos { + t.Errorf("read from %d, expect %d, got %d", pos, endPos-pos, n) + } + if !bytes.Equal(wireData[pos:endPos], msg[:n]) { + t.Errorf("read from %d, frame %v, got %v", pos, wireData[pos:endPos], msg[:n]) + } + step++ + } +} + +func TestHybiServerRead(t *testing.T) { + wireData := []byte{0x81, 0x85, 0xcc, 0x55, 0x80, 0x20, + 0xa4, 0x30, 0xec, 0x4c, 0xa3, // hello + 0x89, 0x85, 0xcc, 0x55, 0x80, 0x20, + 0xa4, 0x30, 0xec, 0x4c, 0xa3, // ping: hello + 0x81, 0x85, 0xed, 0x83, 0xb4, 0x24, + 0x9a, 0xec, 0xc6, 0x48, 0x89, // world + } + br := bufio.NewReader(bytes.NewBuffer(wireData)) + bw := bufio.NewWriter(bytes.NewBuffer([]byte{})) + conn := newHybiConn(newConfig(t, "/"), bufio.NewReadWriter(br, bw), nil, new(http.Request)) + + expected := [][]byte{[]byte("hello"), []byte("world")} + + msg := make([]byte, 512) + n, err := conn.Read(msg) + if err != nil { + t.Errorf("read 1st frame, error %q", err) + } + if n != 5 { + t.Errorf("read 1st frame, expect 5, got %d", n) + } + if !bytes.Equal(expected[0], msg[:n]) { + t.Errorf("read 1st frame %q, got %q", expected[0], msg[:n]) + } + + n, err = conn.Read(msg) + if err != nil { + t.Errorf("read 2nd frame, error %q", err) + } + if n != 5 { + t.Errorf("read 2nd frame, expect 5, got %d", n) + } + if !bytes.Equal(expected[1], msg[:n]) { + t.Errorf("read 2nd frame %q, got %q", expected[1], msg[:n]) + } + + n, err = conn.Read(msg) + if err == nil { + t.Errorf("read not EOF") + } + if n != 0 { + t.Errorf("expect read 0, got %d", n) + } +} + +func TestHybiServerReadWithoutMasking(t *testing.T) { + wireData := []byte{0x81, 0x05, 'h', 'e', 'l', 'l', 'o'} + br := bufio.NewReader(bytes.NewBuffer(wireData)) + bw := bufio.NewWriter(bytes.NewBuffer([]byte{})) + conn := newHybiConn(newConfig(t, "/"), bufio.NewReadWriter(br, bw), nil, new(http.Request)) + // server MUST close the connection upon receiving a non-masked frame. + msg := make([]byte, 512) + _, err := conn.Read(msg) + if err != os.EOF { + t.Errorf("read 1st frame, expect %q, but got %q", os.EOF, err) + } +} + +func TestHybiClientReadWithMasking(t *testing.T) { + wireData := []byte{0x81, 0x85, 0xcc, 0x55, 0x80, 0x20, + 0xa4, 0x30, 0xec, 0x4c, 0xa3, // hello + } + br := bufio.NewReader(bytes.NewBuffer(wireData)) + bw := bufio.NewWriter(bytes.NewBuffer([]byte{})) + conn := newHybiConn(newConfig(t, "/"), bufio.NewReadWriter(br, bw), nil, nil) + + // client MUST close the connection upon receiving a masked frame. + msg := make([]byte, 512) + _, err := conn.Read(msg) + if err != os.EOF { + t.Errorf("read 1st frame, expect %q, but got %q", os.EOF, err) + } +} + +// Test the hybiServerHandshaker supports firefox implementation and +// checks Connection request header include (but it's not necessary +// equal to) "upgrade" +func TestHybiServerFirefoxHandshake(t *testing.T) { + config := new(Config) + handshaker := &hybiServerHandshaker{Config: config} + br := bufio.NewReader(strings.NewReader(`GET /chat HTTP/1.1 +Host: server.example.com +Upgrade: websocket +Connection: keep-alive, upgrade +Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ== +Origin: http://example.com +Sec-WebSocket-Protocol: chat, superchat +Sec-WebSocket-Version: 13 + +`)) + req, err := http.ReadRequest(br) + if err != nil { + t.Fatal("request", err) + } + code, err := handshaker.ReadHandshake(br, req) + if err != nil { + t.Errorf("handshake failed: %v", err) + } + if code != http.StatusSwitchingProtocols { + t.Errorf("status expected %q but got %q", http.StatusSwitchingProtocols, code) + } + b := bytes.NewBuffer([]byte{}) + bw := bufio.NewWriter(b) + + config.Protocol = []string{"chat"} + + err = handshaker.AcceptHandshake(bw) + if err != nil { + t.Errorf("handshake response failed: %v", err) + } + expectedResponse := strings.Join([]string{ + "HTTP/1.1 101 Switching Protocols", + "Upgrade: websocket", + "Connection: Upgrade", + "Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=", + "Sec-WebSocket-Protocol: chat", + "", ""}, "\r\n") + + if b.String() != expectedResponse { + t.Errorf("handshake expected %q but got %q", expectedResponse, b.String()) + } +} diff --git a/libgo/go/websocket/server.go b/libgo/go/websocket/server.go index e0e7c872db4..a1d1d48600f 100644 --- a/libgo/go/websocket/server.go +++ b/libgo/go/websocket/server.go @@ -5,11 +5,48 @@ package websocket import ( + "bufio" + "fmt" "http" "io" - "strings" + "os" ) +func newServerConn(rwc io.ReadWriteCloser, buf *bufio.ReadWriter, req *http.Request) (conn *Conn, err os.Error) { + config := new(Config) + var hs serverHandshaker = &hybiServerHandshaker{Config: config} + code, err := hs.ReadHandshake(buf.Reader, req) + if err == ErrBadWebSocketVersion { + fmt.Fprintf(buf, "HTTP/1.1 %03d %s\r\n", code, http.StatusText(code)) + fmt.Fprintf(buf, "Sec-WebSocket-Version: %s\r\n", SupportedProtocolVersion) + buf.WriteString("\r\n") + buf.WriteString(err.String()) + return + } + if err != nil { + hs = &hixie76ServerHandshaker{Config: config} + code, err = hs.ReadHandshake(buf.Reader, req) + } + if err != nil { + hs = &hixie75ServerHandshaker{Config: config} + code, err = hs.ReadHandshake(buf.Reader, req) + } + if err != nil { + fmt.Fprintf(buf, "HTTP/1.1 %03d %s\r\n", code, http.StatusText(code)) + buf.WriteString("\r\n") + buf.WriteString(err.String()) + return + } + config.Protocol = nil + + err = hs.AcceptHandshake(buf.Writer) + if err != nil { + return + } + conn = hs.NewServerConn(buf, rwc, req) + return +} + /* Handler is an interface to a WebSocket. @@ -23,7 +60,7 @@ A trivial example server: "websocket" ) - // Echo the data received on the Web Socket. + // Echo the data received on the WebSocket. func EchoServer(ws *websocket.Conn) { io.Copy(ws, ws); } @@ -38,26 +75,8 @@ A trivial example server: */ type Handler func(*Conn) -/* -Gets key number from Sec-WebSocket-Key<n>: field as described -in 5.2 Sending the server's opening handshake, 4. -*/ -func getKeyNumber(s string) (r uint32) { - // 4. Let /key-number_n/ be the digits (characters in the range - // U+0030 DIGIT ZERO (0) to U+0039 DIGIT NINE (9)) in /key_1/, - // interpreted as a base ten integer, ignoring all other characters - // in /key_n/. - r = 0 - for i := 0; i < len(s); i++ { - if s[i] >= '0' && s[i] <= '9' { - r = r*10 + uint32(s[i]) - '0' - } - } - return -} - // ServeHTTP implements the http.Handler interface for a Web Socket -func (f Handler) ServeHTTP(w http.ResponseWriter, req *http.Request) { +func (h Handler) ServeHTTP(w http.ResponseWriter, req *http.Request) { rwc, buf, err := w.(http.Hijacker).Hijack() if err != nil { panic("Hijack failed: " + err.String()) @@ -67,153 +86,12 @@ func (f Handler) ServeHTTP(w http.ResponseWriter, req *http.Request) { // the client did not send a handshake that matches with protocol // specification. defer rwc.Close() - - if req.Method != "GET" { - return - } - // HTTP version can be safely ignored. - - if strings.ToLower(req.Header.Get("Upgrade")) != "websocket" || - strings.ToLower(req.Header.Get("Connection")) != "upgrade" { - return - } - - // TODO(ukai): check Host - origin := req.Header.Get("Origin") - if origin == "" { - return - } - - key1 := req.Header.Get("Sec-Websocket-Key1") - if key1 == "" { - return - } - key2 := req.Header.Get("Sec-Websocket-Key2") - if key2 == "" { - return - } - key3 := make([]byte, 8) - if _, err := io.ReadFull(buf, key3); err != nil { - return - } - - var location string - if req.TLS != nil { - location = "wss://" + req.Host + req.URL.RawPath - } else { - location = "ws://" + req.Host + req.URL.RawPath - } - - // Step 4. get key number in Sec-WebSocket-Key<n> fields. - keyNumber1 := getKeyNumber(key1) - keyNumber2 := getKeyNumber(key2) - - // Step 5. get number of spaces in Sec-WebSocket-Key<n> fields. - space1 := uint32(strings.Count(key1, " ")) - space2 := uint32(strings.Count(key2, " ")) - if space1 == 0 || space2 == 0 { - return - } - - // Step 6. key number must be an integral multiple of spaces. - if keyNumber1%space1 != 0 || keyNumber2%space2 != 0 { - return - } - - // Step 7. let part be key number divided by spaces. - part1 := keyNumber1 / space1 - part2 := keyNumber2 / space2 - - // Step 8. let challenge be concatenation of part1, part2 and key3. - // Step 9. get MD5 fingerprint of challenge. - response, err := getChallengeResponse(part1, part2, key3) - if err != nil { - return - } - - // Step 10. send response status line. - buf.WriteString("HTTP/1.1 101 WebSocket Protocol Handshake\r\n") - // Step 11. send response headers. - buf.WriteString("Upgrade: WebSocket\r\n") - buf.WriteString("Connection: Upgrade\r\n") - buf.WriteString("Sec-WebSocket-Location: " + location + "\r\n") - buf.WriteString("Sec-WebSocket-Origin: " + origin + "\r\n") - protocol := strings.TrimSpace(req.Header.Get("Sec-Websocket-Protocol")) - if protocol != "" { - buf.WriteString("Sec-WebSocket-Protocol: " + protocol + "\r\n") - } - // Step 12. send CRLF. - buf.WriteString("\r\n") - // Step 13. send response data. - buf.Write(response) - if err := buf.Flush(); err != nil { - return - } - ws := newConn(origin, location, protocol, buf, rwc) - ws.Request = req - f(ws) -} - -/* -Draft75Handler is an interface to a WebSocket based on the -(soon obsolete) draft-hixie-thewebsocketprotocol-75. -*/ -type Draft75Handler func(*Conn) - -// ServeHTTP implements the http.Handler interface for a Web Socket. -func (f Draft75Handler) ServeHTTP(w http.ResponseWriter, req *http.Request) { - if req.Method != "GET" || req.Proto != "HTTP/1.1" { - w.WriteHeader(http.StatusBadRequest) - io.WriteString(w, "Unexpected request") - return - } - if req.Header.Get("Upgrade") != "WebSocket" { - w.WriteHeader(http.StatusBadRequest) - io.WriteString(w, "missing Upgrade: WebSocket header") - return - } - if req.Header.Get("Connection") != "Upgrade" { - w.WriteHeader(http.StatusBadRequest) - io.WriteString(w, "missing Connection: Upgrade header") - return - } - origin := strings.TrimSpace(req.Header.Get("Origin")) - if origin == "" { - w.WriteHeader(http.StatusBadRequest) - io.WriteString(w, "missing Origin header") - return - } - - rwc, buf, err := w.(http.Hijacker).Hijack() + conn, err := newServerConn(rwc, buf, req) if err != nil { - panic("Hijack failed: " + err.String()) return } - defer rwc.Close() - - var location string - if req.TLS != nil { - location = "wss://" + req.Host + req.URL.RawPath - } else { - location = "ws://" + req.Host + req.URL.RawPath - } - - // TODO(ukai): verify origin,location,protocol. - - buf.WriteString("HTTP/1.1 101 Web Socket Protocol Handshake\r\n") - buf.WriteString("Upgrade: WebSocket\r\n") - buf.WriteString("Connection: Upgrade\r\n") - buf.WriteString("WebSocket-Origin: " + origin + "\r\n") - buf.WriteString("WebSocket-Location: " + location + "\r\n") - protocol := strings.TrimSpace(req.Header.Get("Websocket-Protocol")) - // canonical header key of WebSocket-Protocol. - if protocol != "" { - buf.WriteString("WebSocket-Protocol: " + protocol + "\r\n") - } - buf.WriteString("\r\n") - if err := buf.Flush(); err != nil { - return + if conn == nil { + panic("unepxected nil conn") } - ws := newConn(origin, location, protocol, buf, rwc) - f(ws) + h(conn) } diff --git a/libgo/go/websocket/websocket.go b/libgo/go/websocket/websocket.go index 7447cf85215..a3750dde115 100644 --- a/libgo/go/websocket/websocket.go +++ b/libgo/go/websocket/websocket.go @@ -2,145 +2,246 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// Package websocket implements a client and server for the Web Socket protocol. -// The protocol is defined at http://tools.ietf.org/html/draft-hixie-thewebsocketprotocol +// Package websocket implements a client and server for the WebSocket protocol. +// The protocol is defined at http://tools.ietf.org/html/draft-ietf-hybi-thewebsocketprotocol package websocket -// TODO(ukai): -// better logging. - import ( "bufio" - "crypto/md5" - "encoding/binary" + "crypto/tls" "http" "io" + "io/ioutil" + "json" "net" "os" + "sync" + "url" ) -// WebSocketAddr is an implementation of net.Addr for Web Sockets. -type WebSocketAddr string +const ( + ProtocolVersionHixie75 = -75 + ProtocolVersionHixie76 = -76 + ProtocolVersionHybi00 = 0 + ProtocolVersionHybi08 = 8 + ProtocolVersionHybi13 = 13 + ProtocolVersionHybi = ProtocolVersionHybi13 + SupportedProtocolVersion = "13, 8" + + ContinuationFrame = 0 + TextFrame = 1 + BinaryFrame = 2 + CloseFrame = 8 + PingFrame = 9 + PongFrame = 10 + UnknownFrame = 255 +) -// Network returns the network type for a Web Socket, "websocket". -func (addr WebSocketAddr) Network() string { return "websocket" } +// WebSocket protocol errors. +type ProtocolError struct { + ErrorString string +} -// String returns the network address for a Web Socket. -func (addr WebSocketAddr) String() string { return string(addr) } +func (err *ProtocolError) String() string { return err.ErrorString } -const ( - stateFrameByte = iota - stateFrameLength - stateFrameData - stateFrameTextData +var ( + ErrBadProtocolVersion = &ProtocolError{"bad protocol version"} + ErrBadScheme = &ProtocolError{"bad scheme"} + ErrBadStatus = &ProtocolError{"bad status"} + ErrBadUpgrade = &ProtocolError{"missing or bad upgrade"} + ErrBadWebSocketOrigin = &ProtocolError{"missing or bad WebSocket-Origin"} + ErrBadWebSocketLocation = &ProtocolError{"missing or bad WebSocket-Location"} + ErrBadWebSocketProtocol = &ProtocolError{"missing or bad WebSocket-Protocol"} + ErrBadWebSocketVersion = &ProtocolError{"missing or bad WebSocket Version"} + ErrChallengeResponse = &ProtocolError{"mismatch challenge/response"} + ErrBadFrame = &ProtocolError{"bad frame"} + ErrBadFrameBoundary = &ProtocolError{"not on frame boundary"} + ErrNotWebSocket = &ProtocolError{"not websocket protocol"} + ErrBadRequestMethod = &ProtocolError{"bad method"} + ErrNotSupported = &ProtocolError{"not supported"} ) -// Conn is a channel to communicate to a Web Socket. -// It implements the net.Conn interface. +// Addr is an implementation of net.Addr for WebSocket. +type Addr struct { + *url.URL +} + +// Network returns the network type for a WebSocket, "websocket". +func (addr *Addr) Network() string { return "websocket" } + +// Config is a WebSocket configuration +type Config struct { + // A WebSocket server address. + Location *url.URL + + // A Websocket client origin. + Origin *url.URL + + // WebSocket subprotocols. + Protocol []string + + // WebSocket protocol version. + Version int + + // TLS config for secure WebSocket (wss). + TlsConfig *tls.Config + + handshakeData map[string]string +} + +// serverHandshaker is an interface to handle WebSocket server side handshake. +type serverHandshaker interface { + // ReadHandshake reads handshake request message from client. + // Returns http response code and error if any. + ReadHandshake(buf *bufio.Reader, req *http.Request) (code int, err os.Error) + + // AcceptHandshake accepts the client handshake request and sends + // handshake response back to client. + AcceptHandshake(buf *bufio.Writer) (err os.Error) + + // NewServerConn creates a new WebSocket connection. + NewServerConn(buf *bufio.ReadWriter, rwc io.ReadWriteCloser, request *http.Request) (conn *Conn) +} + +// frameReader is an interface to read a WebSocket frame. +type frameReader interface { + // Reader is to read payload of the frame. + io.Reader + + // PayloadType returns payload type. + PayloadType() byte + + // HeaderReader returns a reader to read header of the frame. + HeaderReader() io.Reader + + // TrailerReader returns a reader to read trailer of the frame. + // If it returns nil, there is no trailer in the frame. + TrailerReader() io.Reader + + // Len returns total length of the frame, including header and trailer. + Len() int +} + +// frameReaderFactory is an interface to creates new frame reader. +type frameReaderFactory interface { + NewFrameReader() (r frameReader, err os.Error) +} + +// frameWriter is an interface to write a WebSocket frame. +type frameWriter interface { + // Writer is to write playload of the frame. + io.WriteCloser +} + +// frameWriterFactory is an interface to create new frame writer. +type frameWriterFactory interface { + NewFrameWriter(payloadType byte) (w frameWriter, err os.Error) +} + +type frameHandler interface { + HandleFrame(frame frameReader) (r frameReader, err os.Error) + WriteClose(status int) (err os.Error) +} + +// Conn represents a WebSocket connection. type Conn struct { - // The origin URI for the Web Socket. - Origin string - // The location URI for the Web Socket. - Location string - // The subprotocol for the Web Socket. - Protocol string - // The initial http Request (for the Server side only). - Request *http.Request + config *Config + request *http.Request buf *bufio.ReadWriter rwc io.ReadWriteCloser - // It holds text data in previous Read() that failed with small buffer. - data []byte - reading bool -} + rio sync.Mutex + frameReaderFactory + frameReader -// newConn creates a new Web Socket. -func newConn(origin, location, protocol string, buf *bufio.ReadWriter, rwc io.ReadWriteCloser) *Conn { - if buf == nil { - br := bufio.NewReader(rwc) - bw := bufio.NewWriter(rwc) - buf = bufio.NewReadWriter(br, bw) - } - ws := &Conn{Origin: origin, Location: location, Protocol: protocol, buf: buf, rwc: rwc} - return ws + wio sync.Mutex + frameWriterFactory + + frameHandler + PayloadType byte + defaultCloseStatus int } -// Read implements the io.Reader interface for a Conn. +// Read implements the io.Reader interface: +// it reads data of a frame from the WebSocket connection. +// if msg is not large enough for the frame data, it fills the msg and next Read +// will read the rest of the frame data. +// it reads Text frame or Binary frame. func (ws *Conn) Read(msg []byte) (n int, err os.Error) { -Frame: - for !ws.reading && len(ws.data) == 0 { - // Beginning of frame, possibly. - b, err := ws.buf.ReadByte() + ws.rio.Lock() + defer ws.rio.Unlock() +again: + if ws.frameReader == nil { + frame, err := ws.frameReaderFactory.NewFrameReader() if err != nil { return 0, err } - if b&0x80 == 0x80 { - // Skip length frame. - length := 0 - for { - c, err := ws.buf.ReadByte() - if err != nil { - return 0, err - } - length = length*128 + int(c&0x7f) - if c&0x80 == 0 { - break - } - } - for length > 0 { - _, err := ws.buf.ReadByte() - if err != nil { - return 0, err - } - } - continue Frame + ws.frameReader, err = ws.frameHandler.HandleFrame(frame) + if err != nil { + return 0, err } - // In text mode - if b != 0 { - // Skip this frame - for { - c, err := ws.buf.ReadByte() - if err != nil { - return 0, err - } - if c == '\xff' { - break - } - } - continue Frame + if ws.frameReader == nil { + goto again } - ws.reading = true } - if len(ws.data) == 0 { - ws.data, err = ws.buf.ReadSlice('\xff') - if err == nil { - ws.reading = false - ws.data = ws.data[:len(ws.data)-1] // trim \xff + n, err = ws.frameReader.Read(msg) + if err == os.EOF { + if trailer := ws.frameReader.TrailerReader(); trailer != nil { + io.Copy(ioutil.Discard, trailer) } + ws.frameReader = nil + goto again } - n = copy(msg, ws.data) - ws.data = ws.data[n:] return n, err } -// Write implements the io.Writer interface for a Conn. +// Write implements the io.Writer interface: +// it writes data as a frame to the WebSocket connection. func (ws *Conn) Write(msg []byte) (n int, err os.Error) { - ws.buf.WriteByte(0) - ws.buf.Write(msg) - ws.buf.WriteByte(0xff) - err = ws.buf.Flush() - return len(msg), err + ws.wio.Lock() + defer ws.wio.Unlock() + w, err := ws.frameWriterFactory.NewFrameWriter(ws.PayloadType) + if err != nil { + return 0, err + } + n, err = w.Write(msg) + w.Close() + if err != nil { + return n, err + } + return n, err +} + +// Close implements the io.Closer interface. +func (ws *Conn) Close() os.Error { + err := ws.frameHandler.WriteClose(ws.defaultCloseStatus) + if err != nil { + return err + } + return ws.rwc.Close() } -// Close implements the io.Closer interface for a Conn. -func (ws *Conn) Close() os.Error { return ws.rwc.Close() } +func (ws *Conn) IsClientConn() bool { return ws.request == nil } +func (ws *Conn) IsServerConn() bool { return ws.request != nil } -// LocalAddr returns the WebSocket Origin for the connection. -func (ws *Conn) LocalAddr() net.Addr { return WebSocketAddr(ws.Origin) } +// LocalAddr returns the WebSocket Origin for the connection for client, or +// the WebSocket location for server. +func (ws *Conn) LocalAddr() net.Addr { + if ws.IsClientConn() { + return &Addr{ws.config.Origin} + } + return &Addr{ws.config.Location} +} -// RemoteAddr returns the WebSocket locations for the connection. -func (ws *Conn) RemoteAddr() net.Addr { return WebSocketAddr(ws.Location) } +// RemoteAddr returns the WebSocket location for the connection for client, or +// the Websocket Origin for server. +func (ws *Conn) RemoteAddr() net.Addr { + if ws.IsClientConn() { + return &Addr{ws.config.Location} + } + return &Addr{ws.config.Origin} +} // SetTimeout sets the connection's network timeout in nanoseconds. func (ws *Conn) SetTimeout(nsec int64) os.Error { @@ -166,27 +267,143 @@ func (ws *Conn) SetWriteTimeout(nsec int64) os.Error { return os.EINVAL } -// getChallengeResponse computes the expected response from the -// challenge as described in section 5.1 Opening Handshake steps 42 to -// 43 of http://www.whatwg.org/specs/web-socket-protocol/ -func getChallengeResponse(number1, number2 uint32, key3 []byte) (expected []byte, err os.Error) { - // 41. Let /challenge/ be the concatenation of /number_1/, expressed - // a big-endian 32 bit integer, /number_2/, expressed in a big- - // endian 32 bit integer, and the eight bytes of /key_3/ in the - // order they were sent to the wire. - challenge := make([]byte, 16) - binary.BigEndian.PutUint32(challenge[0:], number1) - binary.BigEndian.PutUint32(challenge[4:], number2) - copy(challenge[8:], key3) +// Config returns the WebSocket config. +func (ws *Conn) Config() *Config { return ws.config } + +// Request returns the http request upgraded to the WebSocket. +// It is nil for client side. +func (ws *Conn) Request() *http.Request { return ws.request } + +// Codec represents a symmetric pair of functions that implement a codec. +type Codec struct { + Marshal func(v interface{}) (data []byte, payloadType byte, err os.Error) + Unmarshal func(data []byte, payloadType byte, v interface{}) (err os.Error) +} + +// Send sends v marshaled by cd.Marshal as single frame to ws. +func (cd Codec) Send(ws *Conn, v interface{}) (err os.Error) { + if err != nil { + return err + } + data, payloadType, err := cd.Marshal(v) + if err != nil { + return err + } + ws.wio.Lock() + defer ws.wio.Unlock() + w, err := ws.frameWriterFactory.NewFrameWriter(payloadType) + _, err = w.Write(data) + w.Close() + return err +} + +// Receive receives single frame from ws, unmarshaled by cd.Unmarshal and stores in v. +func (cd Codec) Receive(ws *Conn, v interface{}) (err os.Error) { + ws.rio.Lock() + defer ws.rio.Unlock() + if ws.frameReader != nil { + _, err = io.Copy(ioutil.Discard, ws.frameReader) + if err != nil { + return err + } + ws.frameReader = nil + } +again: + frame, err := ws.frameReaderFactory.NewFrameReader() + if err != nil { + return err + } + frame, err = ws.frameHandler.HandleFrame(frame) + if err != nil { + return err + } + if frame == nil { + goto again + } + payloadType := frame.PayloadType() + data, err := ioutil.ReadAll(frame) + if err != nil { + return err + } + return cd.Unmarshal(data, payloadType, v) +} - // 42. Let /expected/ be the MD5 fingerprint of /challenge/ as a big- - // endian 128 bit string. - h := md5.New() - if _, err = h.Write(challenge); err != nil { - return +func marshal(v interface{}) (msg []byte, payloadType byte, err os.Error) { + switch data := v.(type) { + case string: + return []byte(data), TextFrame, nil + case []byte: + return data, BinaryFrame, nil } - expected = h.Sum() - return + return nil, UnknownFrame, ErrNotSupported } -var _ net.Conn = (*Conn)(nil) // compile-time check that *Conn implements net.Conn. +func unmarshal(msg []byte, payloadType byte, v interface{}) (err os.Error) { + switch data := v.(type) { + case *string: + *data = string(msg) + return nil + case *[]byte: + *data = msg + return nil + } + return ErrNotSupported +} + +/* +Message is a codec to send/receive text/binary data in a frame on WebSocket connection. +To send/receive text frame, use string type. +To send/receive binary frame, use []byte type. + +Trivial usage: + + import "websocket" + + // receive text frame + var message string + websocket.Message.Receive(ws, &message) + + // send text frame + message = "hello" + websocket.Message.Send(ws, message) + + // receive binary frame + var data []byte + websocket.Message.Receive(ws, &data) + + // send binary frame + data = []byte{0, 1, 2} + websocket.Message.Send(ws, data) + +*/ +var Message = Codec{marshal, unmarshal} + +func jsonMarshal(v interface{}) (msg []byte, payloadType byte, err os.Error) { + msg, err = json.Marshal(v) + return msg, TextFrame, err +} + +func jsonUnmarshal(msg []byte, payloadType byte, v interface{}) (err os.Error) { + return json.Unmarshal(msg, v) +} + +/* +JSON is a codec to send/receive JSON data in a frame from a WebSocket connection. + +Trival usage: + + import "websocket" + + type T struct { + Msg string + Count int + } + + // receive JSON type T + var data T + websocket.JSON.Receive(ws, &data) + + // send JSON type T + websocket.JSON.Send(ws, data) +*/ +var JSON = Codec{jsonMarshal, jsonUnmarshal} diff --git a/libgo/go/websocket/websocket_test.go b/libgo/go/websocket/websocket_test.go index 71c3c8514b7..240af4e49bb 100644 --- a/libgo/go/websocket/websocket_test.go +++ b/libgo/go/websocket/websocket_test.go @@ -5,7 +5,6 @@ package websocket import ( - "bufio" "bytes" "fmt" "http" @@ -13,6 +12,7 @@ import ( "io" "log" "net" + "strings" "sync" "testing" "url" @@ -23,31 +23,38 @@ var once sync.Once func echoServer(ws *Conn) { io.Copy(ws, ws) } +type Count struct { + S string + N int +} + +func countServer(ws *Conn) { + for { + var count Count + err := JSON.Receive(ws, &count) + if err != nil { + return + } + count.N++ + count.S = strings.Repeat(count.S, count.N) + err = JSON.Send(ws, count) + if err != nil { + return + } + } +} + func startServer() { http.Handle("/echo", Handler(echoServer)) - http.Handle("/echoDraft75", Draft75Handler(echoServer)) + http.Handle("/count", Handler(countServer)) server := httptest.NewServer(nil) serverAddr = server.Listener.Addr().String() log.Print("Test WebSocket server listening on ", serverAddr) } -// Test the getChallengeResponse function with values from section -// 5.1 of the specification steps 18, 26, and 43 from -// http://www.whatwg.org/specs/web-socket-protocol/ -func TestChallenge(t *testing.T) { - var part1 uint32 = 777007543 - var part2 uint32 = 114997259 - key3 := []byte{0x47, 0x30, 0x22, 0x2D, 0x5A, 0x3F, 0x47, 0x58} - expected := []byte("0st3Rl&q-2ZU^weu") - - response, err := getChallengeResponse(part1, part2, key3) - if err != nil { - t.Errorf("getChallengeResponse: returned error %v", err) - return - } - if !bytes.Equal(expected, response) { - t.Errorf("getChallengeResponse: expected %q got %q", expected, response) - } +func newConfig(t *testing.T, path string) *Config { + config, _ := NewConfig(fmt.Sprintf("ws://%s%s", serverAddr, path), "http://localhost") + return config } func TestEcho(t *testing.T) { @@ -58,19 +65,18 @@ func TestEcho(t *testing.T) { if err != nil { t.Fatal("dialing", err) } - ws, err := newClient("/echo", "localhost", "http://localhost", - "ws://localhost/echo", "", client, handshake) + conn, err := NewClient(newConfig(t, "/echo"), client) if err != nil { t.Errorf("WebSocket handshake error: %v", err) return } msg := []byte("hello, world\n") - if _, err := ws.Write(msg); err != nil { + if _, err := conn.Write(msg); err != nil { t.Errorf("Write: %v", err) } var actual_msg = make([]byte, 512) - n, err := ws.Read(actual_msg) + n, err := conn.Read(actual_msg) if err != nil { t.Errorf("Read: %v", err) } @@ -78,10 +84,10 @@ func TestEcho(t *testing.T) { if !bytes.Equal(msg, actual_msg) { t.Errorf("Echo: expected %q got %q", msg, actual_msg) } - ws.Close() + conn.Close() } -func TestEchoDraft75(t *testing.T) { +func TestAddr(t *testing.T) { once.Do(startServer) // websocket.Dial() @@ -89,27 +95,64 @@ func TestEchoDraft75(t *testing.T) { if err != nil { t.Fatal("dialing", err) } - ws, err := newClient("/echoDraft75", "localhost", "http://localhost", - "ws://localhost/echoDraft75", "", client, draft75handshake) + conn, err := NewClient(newConfig(t, "/echo"), client) if err != nil { - t.Errorf("WebSocket handshake: %v", err) + t.Errorf("WebSocket handshake error: %v", err) return } - msg := []byte("hello, world\n") - if _, err := ws.Write(msg); err != nil { - t.Errorf("Write: error %v", err) + ra := conn.RemoteAddr().String() + if !strings.HasPrefix(ra, "ws://") || !strings.HasSuffix(ra, "/echo") { + t.Errorf("Bad remote addr: %v", ra) } - var actual_msg = make([]byte, 512) - n, err := ws.Read(actual_msg) + la := conn.LocalAddr().String() + if !strings.HasPrefix(la, "http://") { + t.Errorf("Bad local addr: %v", la) + } + conn.Close() +} + +func TestCount(t *testing.T) { + once.Do(startServer) + + // websocket.Dial() + client, err := net.Dial("tcp", serverAddr) if err != nil { - t.Errorf("Read: error %v", err) + t.Fatal("dialing", err) } - actual_msg = actual_msg[0:n] - if !bytes.Equal(msg, actual_msg) { - t.Errorf("Echo: expected %q got %q", msg, actual_msg) + conn, err := NewClient(newConfig(t, "/count"), client) + if err != nil { + t.Errorf("WebSocket handshake error: %v", err) + return } - ws.Close() + + var count Count + count.S = "hello" + if err := JSON.Send(conn, count); err != nil { + t.Errorf("Write: %v", err) + } + if err := JSON.Receive(conn, &count); err != nil { + t.Errorf("Read: %v", err) + } + if count.N != 1 { + t.Errorf("count: expected %d got %d", 1, count.N) + } + if count.S != "hello" { + t.Errorf("count: expected %q got %q", "hello", count.S) + } + if err := JSON.Send(conn, count); err != nil { + t.Errorf("Write: %v", err) + } + if err := JSON.Receive(conn, &count); err != nil { + t.Errorf("Read: %v", err) + } + if count.N != 2 { + t.Errorf("count: expected %d got %d", 2, count.N) + } + if count.S != "hellohello" { + t.Errorf("count: expected %q got %q", "hellohello", count.S) + } + conn.Close() } func TestWithQuery(t *testing.T) { @@ -120,8 +163,13 @@ func TestWithQuery(t *testing.T) { t.Fatal("dialing", err) } - ws, err := newClient("/echo?q=v", "localhost", "http://localhost", - "ws://localhost/echo?q=v", "", client, handshake) + config := newConfig(t, "/echo") + config.Location, err = url.ParseRequest(fmt.Sprintf("ws://%s/echo?q=v", serverAddr)) + if err != nil { + t.Fatal("location url", err) + } + + ws, err := NewClient(config, client) if err != nil { t.Errorf("WebSocket handshake: %v", err) return @@ -137,8 +185,10 @@ func TestWithProtocol(t *testing.T) { t.Fatal("dialing", err) } - ws, err := newClient("/echo", "localhost", "http://localhost", - "ws://localhost/echo", "test", client, handshake) + config := newConfig(t, "/echo") + config.Protocol = append(config.Protocol, "test") + + ws, err := NewClient(config, client) if err != nil { t.Errorf("WebSocket handshake: %v", err) return @@ -167,29 +217,17 @@ func TestHTTP(t *testing.T) { } } -func TestHTTPDraft75(t *testing.T) { - once.Do(startServer) - - r, err := http.Get(fmt.Sprintf("http://%s/echoDraft75", serverAddr)) - if err != nil { - t.Errorf("Get: error %#v", err) - return - } - if r.StatusCode != http.StatusBadRequest { - t.Errorf("Get: got status %d", r.StatusCode) - } -} - func TestTrailingSpaces(t *testing.T) { // http://code.google.com/p/go/issues/detail?id=955 // The last runs of this create keys with trailing spaces that should not be // generated by the client. once.Do(startServer) + config := newConfig(t, "/echo") for i := 0; i < 30; i++ { // body - ws, err := Dial(fmt.Sprintf("ws://%s/echo", serverAddr), "", "http://localhost/") + ws, err := DialConfig(config) if err != nil { - t.Error("Dial failed:", err.String()) + t.Errorf("Dial #%d failed: %v", i, err) break } ws.Close() @@ -206,19 +244,18 @@ func TestSmallBuffer(t *testing.T) { if err != nil { t.Fatal("dialing", err) } - ws, err := newClient("/echo", "localhost", "http://localhost", - "ws://localhost/echo", "", client, handshake) + conn, err := NewClient(newConfig(t, "/echo"), client) if err != nil { t.Errorf("WebSocket handshake error: %v", err) return } msg := []byte("hello, world\n") - if _, err := ws.Write(msg); err != nil { + if _, err := conn.Write(msg); err != nil { t.Errorf("Write: %v", err) } var small_msg = make([]byte, 8) - n, err := ws.Read(small_msg) + n, err := conn.Read(small_msg) if err != nil { t.Errorf("Read: %v", err) } @@ -226,7 +263,7 @@ func TestSmallBuffer(t *testing.T) { t.Errorf("Echo: expected %q got %q", msg[:len(small_msg)], small_msg) } var second_msg = make([]byte, len(msg)) - n, err = ws.Read(second_msg) + n, err = conn.Read(second_msg) if err != nil { t.Errorf("Read: %v", err) } @@ -234,38 +271,5 @@ func TestSmallBuffer(t *testing.T) { if !bytes.Equal(msg[len(small_msg):], second_msg) { t.Errorf("Echo: expected %q got %q", msg[len(small_msg):], second_msg) } - ws.Close() - -} - -func testSkipLengthFrame(t *testing.T) { - b := []byte{'\x80', '\x01', 'x', 0, 'h', 'e', 'l', 'l', 'o', '\xff'} - buf := bytes.NewBuffer(b) - br := bufio.NewReader(buf) - bw := bufio.NewWriter(buf) - ws := newConn("http://127.0.0.1/", "ws://127.0.0.1/", "", bufio.NewReadWriter(br, bw), nil) - msg := make([]byte, 5) - n, err := ws.Read(msg) - if err != nil { - t.Errorf("Read: %v", err) - } - if !bytes.Equal(b[4:8], msg[0:n]) { - t.Errorf("Read: expected %q got %q", msg[4:8], msg[0:n]) - } -} - -func testSkipNoUTF8Frame(t *testing.T) { - b := []byte{'\x01', 'n', '\xff', 0, 'h', 'e', 'l', 'l', 'o', '\xff'} - buf := bytes.NewBuffer(b) - br := bufio.NewReader(buf) - bw := bufio.NewWriter(buf) - ws := newConn("http://127.0.0.1/", "ws://127.0.0.1/", "", bufio.NewReadWriter(br, bw), nil) - msg := make([]byte, 5) - n, err := ws.Read(msg) - if err != nil { - t.Errorf("Read: %v", err) - } - if !bytes.Equal(b[4:8], msg[0:n]) { - t.Errorf("Read: expected %q got %q", msg[4:8], msg[0:n]) - } + conn.Close() } |