socket.io library. current version.

This commit is contained in:
Gennadii Kovalev 2016-05-18 22:11:10 +02:00
parent 9322a52075
commit 52ee6ec03c
14 changed files with 1608 additions and 0 deletions

157
README.md Normal file
View File

@ -0,0 +1,157 @@
golang socket.io
================
golang implementation of [socket.io](http://socket.io) library, client and server
You can check working chat server, based on this library, at (funstream)[http://funstream.tv]
Examples directory contains simple client and server.
### Installation
go get github.com/graarh/golang-socketio
### Simple server usage
```go
//create
server := gosocketio.NewServer(transport.GetDefaultWebsocketTransport())
//handle connected
server.On(gosocketio.OnConnection, func(c *gosocketio.Channel, args interface{}) {
log.Println("New client connected")
//join them to room
c.Join("chat")
})
type Message struct {
Name string `json:"name"`
Message string `json:"message"`
}
//handle custom event
server.On("send", func(c *gosocketio.Channel, msg Message) string {
//send event to all in room
c.BroadcastTo("chat", "message", msg)
return "OK"
})
//setup http server
serveMux := http.NewServeMux()
serveMux.Handle("/socket.io/", server)
log.Panic(http.ListenAndServe(":80", serveMux))
```
### Javascript client for this server
```javascript
var socket = io('ws://yourdomain.com', {transports: ['websocket']});
// listen for messages
socket.on('message', function(message) {
console.log('new message');
console.log(message);
});
socket.on('connect', function () {
console.log('socket connected');
//send something
socket.emit('send', {name: "my name", message: "hello"}, function(result) {
console.log('sended successfully);
console.log(result)
});
});
```
### Server, detailed usage
```go
//create server instance, you can setup transport parameters or get the default one
//look at websocket.go for parameters description
server := gosocketio.NewServer(transport.GetDefaultWebsocketTransport())
// --- this is default handlers
//on connection handler, occurs once for each connected client
server.On(gosocketio.OnConnection, func(c *gosocketio.Channel, args interface{}) {
//client id is unique
log.Println("New client connected, client id is ", c.Id())
//you can join clients to rooms
c.Join("room name")
//of course, you can list the clients in the room, or account them
channels, _ := c.List(data.Channel)
log.Println(len(channels), "clients in room")
})
//on disconnection handler, if client hangs connection unexpectedly, it will still occurs
server.On(gosocketio.OnDisconnection, func(c *gosocketio.Channel, args interface{}) {
//this is not necessary, client will be removed from rooms
//automatically on disconnect
//but you can remove client from room whenever you need to
c.Leave("room name")
log.Println("Disconnected")
})
//error catching handler
server.On(gosocketio.OnError, func(c *gosocketio.Channel, args interface{}) {
log.Println("Error occurs")
})
// --- this is custom handler
//custom event handler
server.On("handle something", func(c *gosocketio.Channel, channel Channel) string {
log.Println("Something successfully handled")
//you can return result of handler, in this case
//handler will be converted from "emit" to "ack"
return "result"
})
//you can get client connection by it's id
channel, _ := server.GetChannel("client id here")
//and send the event to the client
type MyEventData struct {
Data: string
}
channel.Emit("my event", MyEventData{"my data"})
//or you can send ack to client and get result back
result, err := channel.Ack("my custom ack", MyEventData{"ack data"}, time.Second * 5)
//you can broadcast to all clients
server.BroadcastToAll("my event", MyEventData{"broadcast"})
//or for clients joined to room
server.BroadcastTo("my room", "my event", MyEventData{"room broadcast"})
//setup http server like this for handling connections
serveMux := http.NewServeMux()
serveMux.Handle("/socket.io/", server)
log.Panic(http.ListenAndServe(":80", serveMux))
```
### Client
```go
//connect to server, you can use your own transport settings
c, err := gosocketio.Dial("localhost:80", transport.GetDefaultWebsocketTransport())
//do something, handlers and functions are same as server ones
//close connection
c.Close()
```
### Roadmap
1. Tests
2. Travis CI
3. http longpoll transport
4. pure http (short-timed queries) transport
5. binary format

64
ack.go Normal file
View File

@ -0,0 +1,64 @@
package gosocketio
import (
"errors"
"sync"
)
var (
ErrorWaiterNotFound = errors.New("Waiter not found")
)
/**
Processes functions that require answers, also known as acknowledge or ack
*/
type ackProcessor struct {
counter int
counterLock sync.Mutex
resultWaiters map[int](chan string)
resultWaitersLock sync.RWMutex
}
/**
get next id of ack call
*/
func (a *ackProcessor) getNextId() int {
a.counterLock.Lock()
defer a.counterLock.Unlock()
a.counter++
return a.counter
}
/**
Just before the ack function called, the waiter should be added
to wait and receive response to ack call
*/
func (a *ackProcessor) addWaiter(id int, w chan string) {
a.resultWaitersLock.Lock()
a.resultWaiters[id] = w
a.resultWaitersLock.Unlock()
}
/**
removes waiter that is unnecessary anymore
*/
func (a *ackProcessor) removeWaiter(id int) {
a.resultWaitersLock.Lock()
delete(a.resultWaiters, id)
a.resultWaitersLock.Unlock()
}
/**
check if waiter with given ack id is exists, and returns it
*/
func (a *ackProcessor) getWaiter(id int) (chan string, error) {
a.resultWaitersLock.RLock()
defer a.resultWaitersLock.RUnlock()
if waiter, ok := a.resultWaiters[id]; ok {
return waiter, nil
}
return nil, ErrorWaiterNotFound
}

60
caller.go Normal file
View File

@ -0,0 +1,60 @@
package gosocketio
import (
"errors"
"reflect"
)
type caller struct {
Func reflect.Value
Args reflect.Type
Out bool
}
var (
ErrorCallerNotFunc = errors.New("f is not function")
ErrorCallerNot2Args = errors.New("f should have 2 args")
ErrorCallerMaxOneValue = errors.New("f should return not more than one value")
)
/**
Parses function passed by using reflection, and stores its representation
for further call on message or ack
*/
func newCaller(f interface{}) (*caller, error) {
fVal := reflect.ValueOf(f)
if fVal.Kind() != reflect.Func {
return nil, ErrorCallerNotFunc
}
fType := fVal.Type()
if fType.NumIn() != 2 {
return nil, ErrorCallerNot2Args
}
if fType.NumOut() > 1 {
return nil, ErrorCallerMaxOneValue
}
return &caller{
Func: fVal,
Args: fType.In(1),
Out: fType.NumOut() == 1,
}, nil
}
/**
returns function parameter as it is present in it using reflection
*/
func (c *caller) getArgs() interface{} {
return reflect.New(c.Args).Interface()
}
/**
calls function with given arguments from its representation using reflection
*/
func (c *caller) callFunc(h *Channel, args interface{}) []reflect.Value {
a := []reflect.Value{reflect.ValueOf(h), reflect.ValueOf(args).Elem()}
return c.Func.Call(a)
}

46
client.go Normal file
View File

@ -0,0 +1,46 @@
package gosocketio
import (
"funstream/libs/socket.io/transport"
)
const (
webSocketProtocol = "ws://"
socketioUrl = "/socket.io/?EIO=3&transport=websocket"
)
/**
Socket.io client representation
*/
type Client struct {
methods
Channel
}
/**
connect to host and initialise socket.io protocol
*/
func Dial(host string, tr transport.Transport) (*Client, error) {
c := &Client{}
c.initChannel()
c.initMethods()
var err error
c.conn, err = tr.Connect(host)
if err != nil {
return nil, err
}
go inLoop(&c.Channel, &c.methods)
go outLoop(&c.Channel, &c.methods)
go pinger(&c.Channel)
return c, nil
}
/**
Close client connection
*/
func (c *Client) Close() {
CloseChannel(&c.Channel, &c.methods)
}

72
examples/client.go Normal file
View File

@ -0,0 +1,72 @@
package main
import (
"funstream/libs/socket.io"
"funstream/libs/socket.io/transport"
"log"
"runtime"
"time"
)
type Channel struct {
Channel string `json:"channel"`
}
type Message struct {
Id int `json:"id"`
Channel string `json:"channel"`
Text string `json:"text"`
}
func sendJoin(c *gosocketio.Client) {
log.Println("Acking /join")
result, err := c.Ack("/join", Channel{"main"}, time.Second*5)
if err != nil {
log.Fatal(err)
} else {
log.Println("Ack result to /join: ", result)
}
}
func main() {
runtime.GOMAXPROCS(runtime.NumCPU())
c, err := gosocketio.Dial("localhost:3811", transport.GetDefaultWebsocketTransport())
if err != nil {
log.Fatal(err)
}
err = c.On("/message", func(h *gosocketio.Channel, args Message) {
log.Println("--- Got chat message: ", args)
})
if err != nil {
log.Fatal(err)
}
err = c.On(gosocketio.OnDisconnection, func(h *gosocketio.Channel, args interface{}) {
log.Fatal("Disconnected")
})
if err != nil {
log.Fatal(err)
}
err = c.On(gosocketio.OnConnection, func(h *gosocketio.Channel, args interface{}) {
log.Println("Connected")
})
if err != nil {
log.Fatal(err)
}
time.Sleep(1 * time.Second)
go sendJoin(c)
go sendJoin(c)
go sendJoin(c)
go sendJoin(c)
go sendJoin(c)
time.Sleep(60 * time.Second)
c.Close()
log.Println(" [x] Complete")
}

46
examples/server.go Normal file
View File

@ -0,0 +1,46 @@
package main
import (
"funstream/libs/socket.io"
"funstream/libs/socket.io/transport"
"log"
"net/http"
"time"
)
type Channel struct {
Channel string `json:"channel"`
}
type Message struct {
Id int `json:"id"`
Channel string `json:"channel"`
Text string `json:"text"`
}
func main() {
server := gosocketio.NewServer(transport.GetDefaultWebsocketTransport())
server.On(gosocketio.OnConnection, func(c *gosocketio.Channel, args interface{}) {
log.Println("Connected")
c.Emit("/message", Message{10, "main", "using emit"})
c.Join("test")
c.BroadcastTo("test", "/message", Message{10, "main", "using broadcast"})
})
server.On(gosocketio.OnDisconnection, func(c *gosocketio.Channel, args interface{}) {
log.Println("Disconnected")
})
server.On("/join", func(c *gosocketio.Channel, channel Channel) string {
time.Sleep(2 * time.Second)
return "joined to " + channel.Channel
})
serveMux := http.NewServeMux()
serveMux.Handle("/socket.io/", server)
log.Println("Starting server...")
log.Panic(http.ListenAndServe(":3811", serveMux))
}

128
handler.go Normal file
View File

@ -0,0 +1,128 @@
package gosocketio
import (
"encoding/json"
"funstream/libs/socket.io/protocol"
"sync"
)
const (
OnConnection = "connection"
OnDisconnection = "disconnection"
OnError = "error"
)
/**
System handler function for internal event processing
*/
type systemHandler func(c *Channel)
/**
Contains maps of message processing functions
*/
type methods struct {
messageHandlers map[string]*caller
messageHandlersLock sync.RWMutex
onConnection systemHandler
onDisconnection systemHandler
}
/**
create messageHandlers map
*/
func (m *methods) initMethods() {
m.messageHandlers = make(map[string]*caller)
}
/**
Add message processing function, and bind it to given method
*/
func (m *methods) On(method string, f interface{}) error {
c, err := newCaller(f)
if err != nil {
return err
}
m.messageHandlersLock.Lock()
defer m.messageHandlersLock.Unlock()
m.messageHandlers[method] = c
return nil
}
/**
Find message processing function associated with given method
*/
func (m *methods) findMethod(method string) (*caller, bool) {
m.messageHandlersLock.RLock()
defer m.messageHandlersLock.RUnlock()
f, ok := m.messageHandlers[method]
return f, ok
}
func (m *methods) callLoopEvent(c *Channel, event string) {
if m.onConnection != nil && event == OnConnection {
m.onConnection(c)
}
if m.onDisconnection != nil && event == OnDisconnection {
m.onDisconnection(c)
}
f, ok := m.findMethod(event)
if !ok {
return
}
f.callFunc(c, &struct{}{})
}
/**
Check incoming message
On ack_resp - look for waiter
On ack_req - look for processing function and send ack_resp
On emit - look for processing function
*/
func (m *methods) processIncomingMessage(c *Channel, msg *protocol.Message) {
switch msg.Type {
case protocol.MessageTypeEmit:
f, ok := m.findMethod(msg.Method)
if !ok {
return
}
data := f.getArgs()
err := json.Unmarshal([]byte(msg.Args), &data)
if err != nil {
return
}
f.callFunc(c, data)
case protocol.MessageTypeAckRequest:
f, ok := m.findMethod(msg.Method)
if !ok || !f.Out {
return
}
data := f.getArgs()
err := json.Unmarshal([]byte(msg.Args), &data)
if err != nil {
return
}
result := f.callFunc(c, data)
ack := &protocol.Message{
Type: protocol.MessageTypeAckResponse,
AckId: msg.AckId,
}
send(ack, c, result[0].Interface())
case protocol.MessageTypeAckResponse:
waiter, err := c.ack.getWaiter(msg.AckId)
if err == nil {
waiter <- msg.Args
}
}
}

185
loop.go Normal file
View File

@ -0,0 +1,185 @@
package gosocketio
import (
"encoding/json"
"errors"
"funstream/libs/socket.io/protocol"
"funstream/libs/socket.io/transport"
"sync"
"time"
)
const (
queueBufferSize = 500
)
var (
ErrorWrongHeader = errors.New("Wrong header")
)
/**
engine.io header to send or receive
*/
type Header struct {
Sid string `json:"sid"`
Upgrades []string `json:"upgrades"`
PingInterval int `json:"pingInterval"`
PingTimeout int `json:"pingTimeout"`
}
/**
socket.io connection handler
use IsAlive to check that handler is still working
use Dial to connect to websocket
use In and Out channels for message exchange
Close message means channel is closed
ping is automatic
*/
type Channel struct {
conn transport.Connection
out chan string
header Header
alive bool
aliveLock sync.Mutex
ack ackProcessor
server *Server
ip string
}
/**
create channel, map, and set active
*/
func (c *Channel) initChannel() {
//TODO: queueBufferSize from constant to server or client variable
c.out = make(chan string, queueBufferSize)
c.ack.resultWaiters = make(map[int](chan string))
c.alive = true
}
/**
Get id of current socket connection
*/
func (c *Channel) Id() string {
return c.header.Sid
}
/**
Checks that Channel is still alive
*/
func (c *Channel) IsAlive() bool {
return c.alive
}
/**
Close channel
*/
func CloseChannel(c *Channel, m *methods, args ...interface{}) error {
c.aliveLock.Lock()
defer c.aliveLock.Unlock()
if !c.alive {
//already closed
return nil
}
c.conn.Close()
c.alive = false
//clean outloop
for len(c.out) > 0 {
<-c.out
}
c.out <- protocol.CloseMessage
m.callLoopEvent(c, OnDisconnection)
return nil
}
//incoming messages loop, puts incoming messages to In channel
func inLoop(c *Channel, m *methods) error {
for {
pkg, err := c.conn.GetMessage()
if err != nil {
return CloseChannel(c, m, err)
}
msg, err := protocol.Decode(pkg)
if err != nil {
CloseChannel(c, m, protocol.ErrorWrongPacket)
}
switch msg.Type {
case protocol.MessageTypeOpen:
if err := json.Unmarshal([]byte(msg.Source[1:]), &c.header); err != nil {
CloseChannel(c, m, ErrorWrongHeader)
}
m.callLoopEvent(c, OnConnection)
case protocol.MessageTypePing:
c.out <- protocol.PongMessage
case protocol.MessageTypePong:
default:
go m.processIncomingMessage(c, msg)
}
}
return nil
}
var overflooded map[*Channel]struct{} = make(map[*Channel]struct{})
var overfloodedLock sync.Mutex
func AmountOfOverflooded() int64 {
overfloodedLock.Lock()
defer overfloodedLock.Unlock()
return int64(len(overflooded))
}
/**
outgoing messages loop, sends messages from channel to socket
*/
func outLoop(c *Channel, m *methods) error {
for {
outBufferLen := len(c.out)
if outBufferLen == queueBufferSize {
return CloseChannel(c, m, ErrorSocketOverflood)
} else if outBufferLen > int(queueBufferSize/2) {
overfloodedLock.Lock()
overflooded[c] = struct{}{}
overfloodedLock.Unlock()
} else {
overfloodedLock.Lock()
delete(overflooded, c)
overfloodedLock.Unlock()
}
msg := <-c.out
if msg == protocol.CloseMessage {
return nil
}
err := c.conn.WriteMessage(msg)
if err != nil {
return CloseChannel(c, m, err)
}
}
return nil
}
/**
Pinger sends ping messages for keeping connection alive
*/
func pinger(c *Channel) {
for {
interval, _ := c.conn.PingParams()
time.Sleep(interval)
if !c.IsAlive() {
return
}
c.out <- protocol.PingMessage
}
}

45
protocol/message.go Normal file
View File

@ -0,0 +1,45 @@
package protocol
const (
/**
Message with connection options
*/
MessageTypeOpen = iota
/**
Close connection and destroy all handle routines
*/
MessageTypeClose = iota
/**
Ping request message
*/
MessageTypePing = iota
/**
Pong response message
*/
MessageTypePong = iota
/**
Empty message
*/
MessageTypeEmpty = iota
/**
Emit request, no response
*/
MessageTypeEmit = iota
/**
Emit request, wait for response (ack)
*/
MessageTypeAckRequest = iota
/**
ack response
*/
MessageTypeAckResponse = iota
)
type Message struct {
Type int
AckId int
Method string
Args string
Source string
}

214
protocol/socketio.go Normal file
View File

@ -0,0 +1,214 @@
package protocol
import (
"encoding/json"
"errors"
"strconv"
"strings"
)
const (
open = "0"
msg = "4"
emptyMessage = "40"
commonMessage = "42"
ackMessage = "43"
CloseMessage = "1"
PingMessage = "2"
PongMessage = "3"
)
var (
ErrorWrongMessageType = errors.New("Wrong message type")
ErrorWrongPacket = errors.New("Wrong packet")
)
func typeToText(msgType int) (string, error) {
switch msgType {
case MessageTypeOpen:
return open, nil
case MessageTypeClose:
return CloseMessage, nil
case MessageTypePing:
return PingMessage, nil
case MessageTypePong:
return PongMessage, nil
case MessageTypeEmpty:
return emptyMessage, nil
case MessageTypeEmit, MessageTypeAckRequest:
return commonMessage, nil
case MessageTypeAckResponse:
return ackMessage, nil
}
return "", ErrorWrongMessageType
}
func Encode(msg *Message) (string, error) {
result, err := typeToText(msg.Type)
if err != nil {
return "", err
}
if msg.Type == MessageTypeEmpty || msg.Type == MessageTypePing ||
msg.Type == MessageTypePong {
return result, nil
}
if msg.Type == MessageTypeAckRequest || msg.Type == MessageTypeAckResponse {
result += strconv.Itoa(msg.AckId)
}
if msg.Type == MessageTypeOpen || msg.Type == MessageTypeClose {
return result + msg.Args, nil
}
if msg.Type == MessageTypeAckResponse {
return result + "[" + msg.Args + "]", nil
}
jsonMethod, err := json.Marshal(&msg.Method)
if err != nil {
return "", err
}
return result + "[" + string(jsonMethod) + "," + msg.Args + "]", nil
}
func MustEncode(msg *Message) string {
result, err := Encode(msg)
if err != nil {
panic(err)
}
return result
}
func getMessageType(data string) (int, error) {
if len(data) == 0 {
return 0, ErrorWrongMessageType
}
switch data[0:1] {
case open:
return MessageTypeOpen, nil
case CloseMessage:
return MessageTypeClose, nil
case PingMessage:
return MessageTypePing, nil
case PongMessage:
return MessageTypePong, nil
case msg:
if len(data) == 1 {
return 0, ErrorWrongMessageType
}
switch data[0:2] {
case emptyMessage:
return MessageTypeEmpty, nil
case commonMessage:
return MessageTypeAckRequest, nil
case ackMessage:
return MessageTypeAckResponse, nil
}
}
return 0, ErrorWrongMessageType
}
/**
Get ack id of current packet, if present
*/
func getAck(text string) (ackId int, restText string, err error) {
if len(text) < 4 {
return 0, "", ErrorWrongPacket
}
text = text[2:]
pos := strings.IndexByte(text, '[')
if pos == -1 {
return 0, "", ErrorWrongPacket
}
ack, err := strconv.Atoi(text[0:pos])
if err != nil {
return 0, "", err
}
return ack, text[pos:], nil
}
/**
Get message method of current packet, if present
*/
func getMethod(text string) (method, restText string, err error) {
var start, end, rest, countQuote int
for i, c := range text {
if c == '"' {
switch countQuote {
case 0:
start = i + 1
case 1:
end = i
rest = i + 1
default:
return "", "", ErrorWrongPacket
}
countQuote++
}
if c == ',' {
if countQuote < 2 {
continue
}
rest = i + 1
break
}
}
if (end < start) || (rest >= len(text)) {
return "", "", ErrorWrongPacket
}
return text[start:end], text[rest : len(text)-1], nil
}
func Decode(data string) (*Message, error) {
var err error
msg := &Message{}
msg.Source = data
msg.Type, err = getMessageType(data)
if err != nil {
return nil, err
}
if msg.Type == MessageTypeOpen {
msg.Args = data[1:]
return msg, nil
}
if msg.Type == MessageTypeClose || msg.Type == MessageTypePing ||
msg.Type == MessageTypePong || msg.Type == MessageTypeEmpty {
return msg, nil
}
ack, rest, err := getAck(data)
msg.AckId = ack
if msg.Type == MessageTypeAckResponse {
if err != nil {
return nil, err
}
msg.Args = rest[1 : len(rest)-1]
return msg, nil
}
if err != nil {
msg.Type = MessageTypeEmit
rest = data[2:]
}
msg.Method, msg.Args, err = getMethod(rest)
if err != nil {
return nil, err
}
return msg, nil
}

77
send.go Normal file
View File

@ -0,0 +1,77 @@
package gosocketio
import (
"encoding/json"
"errors"
"time"
"funstream/libs/socket.io/protocol"
)
var (
ErrorSendTimeout = errors.New("Timeout")
ErrorSocketOverflood = errors.New("Socket overflood")
)
/**
Send message packet to socket
*/
func send(msg *protocol.Message, c *Channel, args interface{}) error {
json, err := json.Marshal(&args)
if err != nil {
return err
}
msg.Args = string(json)
command, err := protocol.Encode(msg)
if err != nil {
return err
}
if len(c.out) == queueBufferSize {
return ErrorSocketOverflood
}
c.out <- command
return nil
}
/**
Create packet based on given data and send it
*/
func (c *Channel) Emit(method string, args interface{}) error {
msg := &protocol.Message{
Type: protocol.MessageTypeEmit,
Method: method,
}
return send(msg, c, args)
}
/**
Create ack packet based on given data and send it and receive response
*/
func (c *Channel) Ack(method string, args interface{}, timeout time.Duration) (string, error) {
msg := &protocol.Message{
Type: protocol.MessageTypeAckRequest,
AckId: c.ack.getNextId(),
Method: method,
}
waiter := make(chan string)
c.ack.addWaiter(msg.AckId, waiter)
err := send(msg, c, args)
if err != nil {
c.ack.removeWaiter(msg.AckId)
}
select {
case result := <-waiter:
return result, nil
case <-time.After(timeout):
c.ack.removeWaiter(msg.AckId)
return "", ErrorSendTimeout
}
}

325
server.go Normal file
View File

@ -0,0 +1,325 @@
package gosocketio
import (
"bytes"
"crypto/md5"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"funstream/libs/socket.io/transport"
"math/rand"
"net/http"
"sync"
"time"
"funstream/libs/socket.io/protocol"
)
var (
ErrorServerNotSet = errors.New("Server not set")
ErrorConnectionNotFound = errors.New("Connection not found")
)
/**
socket.io server instance
*/
type Server struct {
methods
http.Handler
channels map[string]map[*Channel]struct{}
rooms map[*Channel]map[string]struct{}
channelsLock sync.RWMutex
sids map[string]*Channel
sidsLock sync.RWMutex
tr transport.Transport
}
/**
Get channel by it's sid
*/
func (s *Server) GetChannel(sid string) (*Channel, error) {
s.sidsLock.RLock()
defer s.sidsLock.RUnlock()
c, ok := s.sids[sid]
if !ok {
return nil, ErrorConnectionNotFound
}
return c, nil
}
/**
Join this channel to given room
*/
func (c *Channel) Join(room string) error {
if c.server == nil {
return ErrorServerNotSet
}
c.server.channelsLock.Lock()
defer c.server.channelsLock.Unlock()
cn := c.server.channels
if _, ok := cn[room]; !ok {
cn[room] = make(map[*Channel]struct{})
}
byRoom := c.server.rooms
if _, ok := byRoom[c]; !ok {
byRoom[c] = make(map[string]struct{})
}
cn[room][c] = struct{}{}
byRoom[c][room] = struct{}{}
return nil
}
/**
Remove this channel from given room
*/
func (c *Channel) Leave(room string) error {
if c.server == nil {
return ErrorServerNotSet
}
c.server.channelsLock.Lock()
defer c.server.channelsLock.Unlock()
cn := c.server.channels
if _, ok := cn[room]; ok {
delete(cn[room], c)
if len(cn[room]) == 0 {
delete(cn, room)
}
}
byRoom := c.server.rooms
if _, ok := byRoom[c]; ok {
delete(byRoom[c], room)
}
return nil
}
/**
Get list of channels, joined to given room, using channel
*/
func (c *Channel) List(room string) ([]*Channel, error) {
if c.server == nil {
return nil, ErrorServerNotSet
}
return c.server.List(room)
}
/**
Get list of channels, joined to given room, using server
*/
func (s *Server) List(room string) ([]*Channel, error) {
s.channelsLock.RLock()
defer s.channelsLock.RUnlock()
roomChannels, ok := s.channels[room]
if !ok {
return []*Channel{}, nil
}
i := 0
roomChannelsCopy := make([]*Channel, len(roomChannels))
for channel := range roomChannels {
roomChannelsCopy[i] = channel
i++
}
return roomChannelsCopy, nil
}
func (c *Channel) BroadcastTo(room, method string, args interface{}) {
if c.server == nil {
return
}
c.server.BroadcastTo(room, method, args)
}
/**
Broadcast message to all room channels
*/
func (s *Server) BroadcastTo(room, method string, args interface{}) {
s.channelsLock.RLock()
defer s.channelsLock.RUnlock()
roomChannels, ok := s.channels[room]
if !ok {
return
}
for cn := range roomChannels {
if cn.IsAlive() {
go cn.Emit(method, args)
}
}
}
/**
Broadcast to all clients
*/
func (s *Server) BroadcastToAll(method string, args interface{}) {
s.sidsLock.RLock()
defer s.sidsLock.RUnlock()
for _, cn := range s.sids {
if cn.IsAlive() {
go cn.Emit(method, args)
}
}
}
/**
Generate new id for socket.io connection
*/
func generateNewId(custom string) string {
hash := fmt.Sprintf("%s %s %n %n", custom, time.Now(), rand.Uint32(), rand.Uint32())
buf := bytes.NewBuffer(nil)
sum := md5.Sum([]byte(hash))
encoder := base64.NewEncoder(base64.URLEncoding, buf)
encoder.Write(sum[:])
encoder.Close()
return buf.String()[:20]
}
/**
On connection system handler, store sid
*/
func onConnectStore(c *Channel) {
c.server.sidsLock.Lock()
defer c.server.sidsLock.Unlock()
c.server.sids[c.Id()] = c
}
/**
On disconnection system handler, clean joins and sid
*/
func onDisconnectCleanup(c *Channel) {
c.server.channelsLock.Lock()
defer c.server.channelsLock.Unlock()
cn := c.server.channels
byRoom, ok := c.server.rooms[c]
if ok {
for room := range byRoom {
if curRoom, ok := cn[room]; ok {
delete(curRoom, c)
if len(curRoom) == 0 {
delete(cn, room)
}
}
}
delete(c.server.rooms, c)
}
c.server.sidsLock.Lock()
defer c.server.sidsLock.Unlock()
delete(c.server.sids, c.Id())
}
func (s *Server) SendOpenSequence(c *Channel) {
jsonHdr, err := json.Marshal(&c.header)
if err != nil {
panic(err)
}
c.out <- protocol.MustEncode(
&protocol.Message{
Type: protocol.MessageTypeOpen,
Args: string(jsonHdr),
},
)
c.out <- protocol.MustEncode(&protocol.Message{Type: protocol.MessageTypeEmpty})
}
/**
Setup event loop for given connection
*/
func (s *Server) SetupEventLoop(conn transport.Connection, remoteAddr string) {
interval, timeout := conn.PingParams()
hdr := Header{
Sid: generateNewId(remoteAddr),
Upgrades: []string{},
PingInterval: int(interval / time.Millisecond),
PingTimeout: int(timeout / time.Millisecond),
}
c := &Channel{}
c.conn = conn
c.ip = remoteAddr
c.initChannel()
c.server = s
c.header = hdr
s.SendOpenSequence(c)
go inLoop(c, &s.methods)
go outLoop(c, &s.methods)
s.callLoopEvent(c, OnConnection)
}
/**
implements ServeHTTP function from http.Handler
*/
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
conn, err := s.tr.HandleConnection(w, r)
if err != nil {
return
}
s.SetupEventLoop(conn, r.RemoteAddr)
s.tr.Serve(w, r)
}
/**
Get amount of current connected sids
*/
func (s *Server) AmountOfSids() int64 {
s.sidsLock.RLock()
defer s.sidsLock.RUnlock()
return int64(len(s.sids))
}
/**
Get amount of rooms with at least one channel(or sid) joined
*/
func (s *Server) AmountOfRooms() int64 {
s.channelsLock.RLock()
defer s.channelsLock.RUnlock()
return int64(len(s.channels))
}
/**
Create new socket.io server
*/
func NewServer(tr transport.Transport) *Server {
s := Server{}
s.initMethods()
s.tr = tr
s.channels = make(map[string]map[*Channel]struct{})
s.rooms = make(map[*Channel]map[string]struct{})
s.sids = make(map[string]*Channel)
s.onConnection = onConnectStore
s.onDisconnection = onDisconnectCleanup
return &s
}

51
transport/transport.go Normal file
View File

@ -0,0 +1,51 @@
package transport
import (
"net/http"
"time"
)
/**
End-point connection for given transport
*/
type Connection interface {
/**
Receive one more message, block until received
*/
GetMessage() (message string, err error)
/**
Send given message, block until sent
*/
WriteMessage(message string) error
/**
Close current connection
*/
Close()
/**
Get ping time interval and ping request timeout
*/
PingParams() (interval, timeout time.Duration)
}
/**
Connection factory for given transport
*/
type Transport interface {
/**
Get client connection
*/
Connect(host string) (conn Connection, err error)
/**
Handle one server connection
*/
HandleConnection(w http.ResponseWriter, r *http.Request) (conn Connection, err error)
/**
Serve HTTP request after making connection and events setup
*/
Serve(w http.ResponseWriter, r *http.Request)
}

138
transport/websocket.go Normal file
View File

@ -0,0 +1,138 @@
package transport
import (
"errors"
"github.com/gorilla/websocket"
"io/ioutil"
"net/http"
"time"
)
const (
webSocketProtocol = "ws://"
socketioUrl = "/socket.io/?EIO=3&transport=websocket"
upgradeFailed = "Upgrade failed: "
WsDefaultPingInterval = 30 * time.Second
WsDefaultPingTimeout = 60 * time.Second
WsDefaultReceiveTimeout = 60 * time.Second
WsDefaultSendTimeout = 60 * time.Second
WsDefaultBufferSize = 1024 * 32
)
var (
ErrorBinaryMessage = errors.New("Binary messages are not supported")
ErrorBadBuffer = errors.New("Buffer error")
ErrorPacketWrong = errors.New("Wrong packet type error")
ErrorMethodNotAllowed = errors.New("Method not allowed")
ErrorHttpUpgradeFailed = errors.New("Http upgrade failed")
)
type WebsocketConnection struct {
socket *websocket.Conn
transport *WebsocketTransport
}
func (wsc *WebsocketConnection) GetMessage() (message string, err error) {
wsc.socket.SetReadDeadline(time.Now().Add(wsc.transport.ReceiveTimeout))
msgType, reader, err := wsc.socket.NextReader()
if err != nil {
return "", err
}
//support only text messages exchange
if msgType != websocket.TextMessage {
return "", ErrorBinaryMessage
}
data, err := ioutil.ReadAll(reader)
if err != nil {
return "", ErrorBadBuffer
}
text := string(data)
//empty messages are not allowed
if len(text) == 0 {
return "", ErrorPacketWrong
}
return text, nil
}
func (wsc *WebsocketConnection) WriteMessage(message string) error {
wsc.socket.SetWriteDeadline(time.Now().Add(wsc.transport.SendTimeout))
writer, err := wsc.socket.NextWriter(websocket.TextMessage)
if err != nil {
return err
}
if _, err := writer.Write([]byte(message)); err != nil {
return err
}
if err := writer.Close(); err != nil {
return err
}
return nil
}
func (wsc *WebsocketConnection) Close() {
wsc.socket.Close()
}
func (wsc *WebsocketConnection) PingParams() (interval, timeout time.Duration) {
return wsc.transport.PingInterval, wsc.transport.PingTimeout
}
type WebsocketTransport struct {
PingInterval time.Duration
PingTimeout time.Duration
ReceiveTimeout time.Duration
SendTimeout time.Duration
BufferSize int
}
func (wst *WebsocketTransport) Connect(host string) (conn Connection, err error) {
dialer := websocket.Dialer{}
socket, _, err := dialer.Dial(webSocketProtocol+host+socketioUrl, nil)
if err != nil {
return nil, err
}
return &WebsocketConnection{socket, wst}, nil
}
func (wst *WebsocketTransport) HandleConnection(
w http.ResponseWriter, r *http.Request) (conn Connection, err error) {
if r.Method != "GET" {
http.Error(w, upgradeFailed+ErrorMethodNotAllowed.Error(), 503)
return nil, ErrorMethodNotAllowed
}
socket, err := websocket.Upgrade(w, r, nil, wst.BufferSize, wst.BufferSize)
if err != nil {
http.Error(w, upgradeFailed+err.Error(), 503)
return nil, ErrorHttpUpgradeFailed
}
return &WebsocketConnection{socket, wst}, nil
}
/**
Websocket connection do not require any additional processing
*/
func (wst *WebsocketTransport) Serve(w http.ResponseWriter, r *http.Request) {}
/**
Returns websocket connection with default params
*/
func GetDefaultWebsocketTransport() *WebsocketTransport {
return &WebsocketTransport{
PingInterval: WsDefaultPingInterval,
PingTimeout: WsDefaultPingTimeout,
ReceiveTimeout: WsDefaultReceiveTimeout,
SendTimeout: WsDefaultSendTimeout,
BufferSize: WsDefaultBufferSize,
}
}