diff --git a/README.md b/README.md new file mode 100644 index 0000000..84efd90 --- /dev/null +++ b/README.md @@ -0,0 +1,157 @@ +golang socket.io +================ + +golang implementation of [socket.io](http://socket.io) library, client and server + +You can check working chat server, based on this library, at (funstream)[http://funstream.tv] + +Examples directory contains simple client and server. + +### Installation + + go get github.com/graarh/golang-socketio + +### Simple server usage + +```go + //create + server := gosocketio.NewServer(transport.GetDefaultWebsocketTransport()) + + //handle connected + server.On(gosocketio.OnConnection, func(c *gosocketio.Channel, args interface{}) { + log.Println("New client connected") + //join them to room + c.Join("chat") + }) + + type Message struct { + Name string `json:"name"` + Message string `json:"message"` + } + + //handle custom event + server.On("send", func(c *gosocketio.Channel, msg Message) string { + //send event to all in room + c.BroadcastTo("chat", "message", msg) + return "OK" + }) + + //setup http server + serveMux := http.NewServeMux() + serveMux.Handle("/socket.io/", server) + log.Panic(http.ListenAndServe(":80", serveMux)) +``` + +### Javascript client for this server + +```javascript +var socket = io('ws://yourdomain.com', {transports: ['websocket']}); + + // listen for messages + socket.on('message', function(message) { + + console.log('new message'); + console.log(message); + }); + + socket.on('connect', function () { + + console.log('socket connected'); + + //send something + socket.emit('send', {name: "my name", message: "hello"}, function(result) { + + console.log('sended successfully); + console.log(result) + }); + }); +``` + +### Server, detailed usage + +```go + //create server instance, you can setup transport parameters or get the default one + //look at websocket.go for parameters description + server := gosocketio.NewServer(transport.GetDefaultWebsocketTransport()) + + // --- this is default handlers + + //on connection handler, occurs once for each connected client + server.On(gosocketio.OnConnection, func(c *gosocketio.Channel, args interface{}) { + //client id is unique + log.Println("New client connected, client id is ", c.Id()) + + //you can join clients to rooms + c.Join("room name") + + //of course, you can list the clients in the room, or account them + channels, _ := c.List(data.Channel) + log.Println(len(channels), "clients in room") + }) + //on disconnection handler, if client hangs connection unexpectedly, it will still occurs + server.On(gosocketio.OnDisconnection, func(c *gosocketio.Channel, args interface{}) { + //this is not necessary, client will be removed from rooms + //automatically on disconnect + //but you can remove client from room whenever you need to + c.Leave("room name") + + log.Println("Disconnected") + }) + //error catching handler + server.On(gosocketio.OnError, func(c *gosocketio.Channel, args interface{}) { + log.Println("Error occurs") + }) + + // --- this is custom handler + + //custom event handler + server.On("handle something", func(c *gosocketio.Channel, channel Channel) string { + log.Println("Something successfully handled") + + //you can return result of handler, in this case + //handler will be converted from "emit" to "ack" + return "result" + }) + + //you can get client connection by it's id + channel, _ := server.GetChannel("client id here") + //and send the event to the client + type MyEventData struct { + Data: string + } + channel.Emit("my event", MyEventData{"my data"}) + + //or you can send ack to client and get result back + result, err := channel.Ack("my custom ack", MyEventData{"ack data"}, time.Second * 5) + + //you can broadcast to all clients + server.BroadcastToAll("my event", MyEventData{"broadcast"}) + + //or for clients joined to room + server.BroadcastTo("my room", "my event", MyEventData{"room broadcast"}) + + //setup http server like this for handling connections + serveMux := http.NewServeMux() + serveMux.Handle("/socket.io/", server) + log.Panic(http.ListenAndServe(":80", serveMux)) +``` + +### Client + +```go + //connect to server, you can use your own transport settings + c, err := gosocketio.Dial("localhost:80", transport.GetDefaultWebsocketTransport()) + + //do something, handlers and functions are same as server ones + + //close connection + c.Close() +``` + +### Roadmap + +1. Tests +2. Travis CI +3. http longpoll transport +4. pure http (short-timed queries) transport +5. binary format \ No newline at end of file diff --git a/ack.go b/ack.go new file mode 100644 index 0000000..e889dfc --- /dev/null +++ b/ack.go @@ -0,0 +1,64 @@ +package gosocketio + +import ( + "errors" + "sync" +) + +var ( + ErrorWaiterNotFound = errors.New("Waiter not found") +) + +/** +Processes functions that require answers, also known as acknowledge or ack +*/ +type ackProcessor struct { + counter int + counterLock sync.Mutex + + resultWaiters map[int](chan string) + resultWaitersLock sync.RWMutex +} + +/** +get next id of ack call +*/ +func (a *ackProcessor) getNextId() int { + a.counterLock.Lock() + defer a.counterLock.Unlock() + + a.counter++ + return a.counter +} + +/** +Just before the ack function called, the waiter should be added +to wait and receive response to ack call +*/ +func (a *ackProcessor) addWaiter(id int, w chan string) { + a.resultWaitersLock.Lock() + a.resultWaiters[id] = w + a.resultWaitersLock.Unlock() +} + +/** +removes waiter that is unnecessary anymore +*/ +func (a *ackProcessor) removeWaiter(id int) { + a.resultWaitersLock.Lock() + delete(a.resultWaiters, id) + a.resultWaitersLock.Unlock() +} + +/** +check if waiter with given ack id is exists, and returns it +*/ +func (a *ackProcessor) getWaiter(id int) (chan string, error) { + a.resultWaitersLock.RLock() + defer a.resultWaitersLock.RUnlock() + + if waiter, ok := a.resultWaiters[id]; ok { + return waiter, nil + } + return nil, ErrorWaiterNotFound +} diff --git a/caller.go b/caller.go new file mode 100644 index 0000000..7671947 --- /dev/null +++ b/caller.go @@ -0,0 +1,60 @@ +package gosocketio + +import ( + "errors" + "reflect" +) + +type caller struct { + Func reflect.Value + Args reflect.Type + Out bool +} + +var ( + ErrorCallerNotFunc = errors.New("f is not function") + ErrorCallerNot2Args = errors.New("f should have 2 args") + ErrorCallerMaxOneValue = errors.New("f should return not more than one value") +) + +/** +Parses function passed by using reflection, and stores its representation +for further call on message or ack +*/ +func newCaller(f interface{}) (*caller, error) { + fVal := reflect.ValueOf(f) + if fVal.Kind() != reflect.Func { + return nil, ErrorCallerNotFunc + } + + fType := fVal.Type() + if fType.NumIn() != 2 { + return nil, ErrorCallerNot2Args + } + + if fType.NumOut() > 1 { + return nil, ErrorCallerMaxOneValue + } + + return &caller{ + Func: fVal, + Args: fType.In(1), + Out: fType.NumOut() == 1, + }, nil +} + +/** +returns function parameter as it is present in it using reflection +*/ +func (c *caller) getArgs() interface{} { + return reflect.New(c.Args).Interface() +} + +/** +calls function with given arguments from its representation using reflection +*/ +func (c *caller) callFunc(h *Channel, args interface{}) []reflect.Value { + a := []reflect.Value{reflect.ValueOf(h), reflect.ValueOf(args).Elem()} + + return c.Func.Call(a) +} diff --git a/client.go b/client.go new file mode 100644 index 0000000..ebe9590 --- /dev/null +++ b/client.go @@ -0,0 +1,46 @@ +package gosocketio + +import ( + "funstream/libs/socket.io/transport" +) + +const ( + webSocketProtocol = "ws://" + socketioUrl = "/socket.io/?EIO=3&transport=websocket" +) + +/** +Socket.io client representation +*/ +type Client struct { + methods + Channel +} + +/** +connect to host and initialise socket.io protocol +*/ +func Dial(host string, tr transport.Transport) (*Client, error) { + c := &Client{} + c.initChannel() + c.initMethods() + + var err error + c.conn, err = tr.Connect(host) + if err != nil { + return nil, err + } + + go inLoop(&c.Channel, &c.methods) + go outLoop(&c.Channel, &c.methods) + go pinger(&c.Channel) + + return c, nil +} + +/** +Close client connection +*/ +func (c *Client) Close() { + CloseChannel(&c.Channel, &c.methods) +} diff --git a/examples/client.go b/examples/client.go new file mode 100644 index 0000000..624a0d3 --- /dev/null +++ b/examples/client.go @@ -0,0 +1,72 @@ +package main + +import ( + "funstream/libs/socket.io" + "funstream/libs/socket.io/transport" + "log" + "runtime" + "time" +) + +type Channel struct { + Channel string `json:"channel"` +} + +type Message struct { + Id int `json:"id"` + Channel string `json:"channel"` + Text string `json:"text"` +} + +func sendJoin(c *gosocketio.Client) { + log.Println("Acking /join") + result, err := c.Ack("/join", Channel{"main"}, time.Second*5) + if err != nil { + log.Fatal(err) + } else { + log.Println("Ack result to /join: ", result) + } +} + +func main() { + runtime.GOMAXPROCS(runtime.NumCPU()) + + c, err := gosocketio.Dial("localhost:3811", transport.GetDefaultWebsocketTransport()) + if err != nil { + log.Fatal(err) + } + + err = c.On("/message", func(h *gosocketio.Channel, args Message) { + log.Println("--- Got chat message: ", args) + }) + if err != nil { + log.Fatal(err) + } + + err = c.On(gosocketio.OnDisconnection, func(h *gosocketio.Channel, args interface{}) { + log.Fatal("Disconnected") + }) + if err != nil { + log.Fatal(err) + } + + err = c.On(gosocketio.OnConnection, func(h *gosocketio.Channel, args interface{}) { + log.Println("Connected") + }) + if err != nil { + log.Fatal(err) + } + + time.Sleep(1 * time.Second) + + go sendJoin(c) + go sendJoin(c) + go sendJoin(c) + go sendJoin(c) + go sendJoin(c) + + time.Sleep(60 * time.Second) + c.Close() + + log.Println(" [x] Complete") +} diff --git a/examples/server.go b/examples/server.go new file mode 100644 index 0000000..4ea7474 --- /dev/null +++ b/examples/server.go @@ -0,0 +1,46 @@ +package main + +import ( + "funstream/libs/socket.io" + "funstream/libs/socket.io/transport" + "log" + "net/http" + "time" +) + +type Channel struct { + Channel string `json:"channel"` +} + +type Message struct { + Id int `json:"id"` + Channel string `json:"channel"` + Text string `json:"text"` +} + +func main() { + server := gosocketio.NewServer(transport.GetDefaultWebsocketTransport()) + + server.On(gosocketio.OnConnection, func(c *gosocketio.Channel, args interface{}) { + log.Println("Connected") + + c.Emit("/message", Message{10, "main", "using emit"}) + + c.Join("test") + c.BroadcastTo("test", "/message", Message{10, "main", "using broadcast"}) + }) + server.On(gosocketio.OnDisconnection, func(c *gosocketio.Channel, args interface{}) { + log.Println("Disconnected") + }) + + server.On("/join", func(c *gosocketio.Channel, channel Channel) string { + time.Sleep(2 * time.Second) + return "joined to " + channel.Channel + }) + + serveMux := http.NewServeMux() + serveMux.Handle("/socket.io/", server) + + log.Println("Starting server...") + log.Panic(http.ListenAndServe(":3811", serveMux)) +} diff --git a/handler.go b/handler.go new file mode 100644 index 0000000..4fd1abe --- /dev/null +++ b/handler.go @@ -0,0 +1,128 @@ +package gosocketio + +import ( + "encoding/json" + "funstream/libs/socket.io/protocol" + "sync" +) + +const ( + OnConnection = "connection" + OnDisconnection = "disconnection" + OnError = "error" +) + +/** +System handler function for internal event processing +*/ +type systemHandler func(c *Channel) + +/** +Contains maps of message processing functions +*/ +type methods struct { + messageHandlers map[string]*caller + messageHandlersLock sync.RWMutex + + onConnection systemHandler + onDisconnection systemHandler +} + +/** +create messageHandlers map +*/ +func (m *methods) initMethods() { + m.messageHandlers = make(map[string]*caller) +} + +/** +Add message processing function, and bind it to given method +*/ +func (m *methods) On(method string, f interface{}) error { + c, err := newCaller(f) + if err != nil { + return err + } + + m.messageHandlersLock.Lock() + defer m.messageHandlersLock.Unlock() + m.messageHandlers[method] = c + + return nil +} + +/** +Find message processing function associated with given method +*/ +func (m *methods) findMethod(method string) (*caller, bool) { + m.messageHandlersLock.RLock() + defer m.messageHandlersLock.RUnlock() + + f, ok := m.messageHandlers[method] + return f, ok +} + +func (m *methods) callLoopEvent(c *Channel, event string) { + if m.onConnection != nil && event == OnConnection { + m.onConnection(c) + } + if m.onDisconnection != nil && event == OnDisconnection { + m.onDisconnection(c) + } + + f, ok := m.findMethod(event) + if !ok { + return + } + + f.callFunc(c, &struct{}{}) +} + +/** +Check incoming message +On ack_resp - look for waiter +On ack_req - look for processing function and send ack_resp +On emit - look for processing function +*/ +func (m *methods) processIncomingMessage(c *Channel, msg *protocol.Message) { + switch msg.Type { + case protocol.MessageTypeEmit: + f, ok := m.findMethod(msg.Method) + if !ok { + return + } + + data := f.getArgs() + err := json.Unmarshal([]byte(msg.Args), &data) + if err != nil { + return + } + + f.callFunc(c, data) + + case protocol.MessageTypeAckRequest: + f, ok := m.findMethod(msg.Method) + if !ok || !f.Out { + return + } + + data := f.getArgs() + err := json.Unmarshal([]byte(msg.Args), &data) + if err != nil { + return + } + + result := f.callFunc(c, data) + ack := &protocol.Message{ + Type: protocol.MessageTypeAckResponse, + AckId: msg.AckId, + } + send(ack, c, result[0].Interface()) + + case protocol.MessageTypeAckResponse: + waiter, err := c.ack.getWaiter(msg.AckId) + if err == nil { + waiter <- msg.Args + } + } +} diff --git a/loop.go b/loop.go new file mode 100644 index 0000000..d217b54 --- /dev/null +++ b/loop.go @@ -0,0 +1,185 @@ +package gosocketio + +import ( + "encoding/json" + "errors" + "funstream/libs/socket.io/protocol" + "funstream/libs/socket.io/transport" + "sync" + "time" +) + +const ( + queueBufferSize = 500 +) + +var ( + ErrorWrongHeader = errors.New("Wrong header") +) + +/** +engine.io header to send or receive +*/ +type Header struct { + Sid string `json:"sid"` + Upgrades []string `json:"upgrades"` + PingInterval int `json:"pingInterval"` + PingTimeout int `json:"pingTimeout"` +} + +/** +socket.io connection handler + +use IsAlive to check that handler is still working +use Dial to connect to websocket +use In and Out channels for message exchange +Close message means channel is closed +ping is automatic +*/ +type Channel struct { + conn transport.Connection + + out chan string + header Header + + alive bool + aliveLock sync.Mutex + + ack ackProcessor + + server *Server + ip string +} + +/** +create channel, map, and set active +*/ +func (c *Channel) initChannel() { + //TODO: queueBufferSize from constant to server or client variable + c.out = make(chan string, queueBufferSize) + c.ack.resultWaiters = make(map[int](chan string)) + c.alive = true +} + +/** +Get id of current socket connection +*/ +func (c *Channel) Id() string { + return c.header.Sid +} + +/** +Checks that Channel is still alive +*/ +func (c *Channel) IsAlive() bool { + return c.alive +} + +/** +Close channel +*/ +func CloseChannel(c *Channel, m *methods, args ...interface{}) error { + c.aliveLock.Lock() + defer c.aliveLock.Unlock() + + if !c.alive { + //already closed + return nil + } + + c.conn.Close() + c.alive = false + + //clean outloop + for len(c.out) > 0 { + <-c.out + } + c.out <- protocol.CloseMessage + + m.callLoopEvent(c, OnDisconnection) + return nil +} + +//incoming messages loop, puts incoming messages to In channel +func inLoop(c *Channel, m *methods) error { + for { + pkg, err := c.conn.GetMessage() + if err != nil { + return CloseChannel(c, m, err) + } + msg, err := protocol.Decode(pkg) + if err != nil { + CloseChannel(c, m, protocol.ErrorWrongPacket) + } + + switch msg.Type { + case protocol.MessageTypeOpen: + if err := json.Unmarshal([]byte(msg.Source[1:]), &c.header); err != nil { + CloseChannel(c, m, ErrorWrongHeader) + } + m.callLoopEvent(c, OnConnection) + case protocol.MessageTypePing: + c.out <- protocol.PongMessage + case protocol.MessageTypePong: + default: + go m.processIncomingMessage(c, msg) + } + } + return nil +} + +var overflooded map[*Channel]struct{} = make(map[*Channel]struct{}) +var overfloodedLock sync.Mutex + +func AmountOfOverflooded() int64 { + overfloodedLock.Lock() + defer overfloodedLock.Unlock() + + return int64(len(overflooded)) +} + +/** +outgoing messages loop, sends messages from channel to socket +*/ +func outLoop(c *Channel, m *methods) error { + for { + outBufferLen := len(c.out) + if outBufferLen == queueBufferSize { + return CloseChannel(c, m, ErrorSocketOverflood) + } else if outBufferLen > int(queueBufferSize/2) { + overfloodedLock.Lock() + overflooded[c] = struct{}{} + overfloodedLock.Unlock() + } else { + overfloodedLock.Lock() + delete(overflooded, c) + overfloodedLock.Unlock() + } + + msg := <-c.out + if msg == protocol.CloseMessage { + return nil + } + + err := c.conn.WriteMessage(msg) + if err != nil { + return CloseChannel(c, m, err) + } + } + return nil +} + +/** +Pinger sends ping messages for keeping connection alive +*/ +func pinger(c *Channel) { + for { + interval, _ := c.conn.PingParams() + time.Sleep(interval) + if !c.IsAlive() { + return + } + + c.out <- protocol.PingMessage + } +} diff --git a/protocol/message.go b/protocol/message.go new file mode 100644 index 0000000..3ba9d81 --- /dev/null +++ b/protocol/message.go @@ -0,0 +1,45 @@ +package protocol + +const ( + /** + Message with connection options + */ + MessageTypeOpen = iota + /** + Close connection and destroy all handle routines + */ + MessageTypeClose = iota + /** + Ping request message + */ + MessageTypePing = iota + /** + Pong response message + */ + MessageTypePong = iota + /** + Empty message + */ + MessageTypeEmpty = iota + /** + Emit request, no response + */ + MessageTypeEmit = iota + /** + Emit request, wait for response (ack) + */ + MessageTypeAckRequest = iota + /** + ack response + */ + MessageTypeAckResponse = iota +) + +type Message struct { + Type int + AckId int + Method string + Args string + Source string +} + diff --git a/protocol/socketio.go b/protocol/socketio.go new file mode 100644 index 0000000..ae8b81d --- /dev/null +++ b/protocol/socketio.go @@ -0,0 +1,214 @@ +package protocol + +import ( + "encoding/json" + "errors" + "strconv" + "strings" +) + +const ( + open = "0" + msg = "4" + emptyMessage = "40" + commonMessage = "42" + ackMessage = "43" + + CloseMessage = "1" + PingMessage = "2" + PongMessage = "3" +) + +var ( + ErrorWrongMessageType = errors.New("Wrong message type") + ErrorWrongPacket = errors.New("Wrong packet") +) + +func typeToText(msgType int) (string, error) { + switch msgType { + case MessageTypeOpen: + return open, nil + case MessageTypeClose: + return CloseMessage, nil + case MessageTypePing: + return PingMessage, nil + case MessageTypePong: + return PongMessage, nil + case MessageTypeEmpty: + return emptyMessage, nil + case MessageTypeEmit, MessageTypeAckRequest: + return commonMessage, nil + case MessageTypeAckResponse: + return ackMessage, nil + } + return "", ErrorWrongMessageType +} + +func Encode(msg *Message) (string, error) { + result, err := typeToText(msg.Type) + if err != nil { + return "", err + } + + if msg.Type == MessageTypeEmpty || msg.Type == MessageTypePing || + msg.Type == MessageTypePong { + return result, nil + } + + if msg.Type == MessageTypeAckRequest || msg.Type == MessageTypeAckResponse { + result += strconv.Itoa(msg.AckId) + } + + if msg.Type == MessageTypeOpen || msg.Type == MessageTypeClose { + return result + msg.Args, nil + } + + if msg.Type == MessageTypeAckResponse { + return result + "[" + msg.Args + "]", nil + } + + jsonMethod, err := json.Marshal(&msg.Method) + if err != nil { + return "", err + } + + return result + "[" + string(jsonMethod) + "," + msg.Args + "]", nil +} + +func MustEncode(msg *Message) string { + result, err := Encode(msg) + if err != nil { + panic(err) + } + + return result +} + +func getMessageType(data string) (int, error) { + if len(data) == 0 { + return 0, ErrorWrongMessageType + } + switch data[0:1] { + case open: + return MessageTypeOpen, nil + case CloseMessage: + return MessageTypeClose, nil + case PingMessage: + return MessageTypePing, nil + case PongMessage: + return MessageTypePong, nil + case msg: + if len(data) == 1 { + return 0, ErrorWrongMessageType + } + switch data[0:2] { + case emptyMessage: + return MessageTypeEmpty, nil + case commonMessage: + return MessageTypeAckRequest, nil + case ackMessage: + return MessageTypeAckResponse, nil + } + } + return 0, ErrorWrongMessageType +} + +/** +Get ack id of current packet, if present +*/ +func getAck(text string) (ackId int, restText string, err error) { + if len(text) < 4 { + return 0, "", ErrorWrongPacket + } + text = text[2:] + + pos := strings.IndexByte(text, '[') + if pos == -1 { + return 0, "", ErrorWrongPacket + } + + ack, err := strconv.Atoi(text[0:pos]) + if err != nil { + return 0, "", err + } + + return ack, text[pos:], nil +} + +/** +Get message method of current packet, if present +*/ +func getMethod(text string) (method, restText string, err error) { + var start, end, rest, countQuote int + + for i, c := range text { + if c == '"' { + switch countQuote { + case 0: + start = i + 1 + case 1: + end = i + rest = i + 1 + default: + return "", "", ErrorWrongPacket + } + countQuote++ + } + if c == ',' { + if countQuote < 2 { + continue + } + rest = i + 1 + break + } + } + + if (end < start) || (rest >= len(text)) { + return "", "", ErrorWrongPacket + } + + return text[start:end], text[rest : len(text)-1], nil +} + +func Decode(data string) (*Message, error) { + var err error + msg := &Message{} + msg.Source = data + + msg.Type, err = getMessageType(data) + if err != nil { + return nil, err + } + + if msg.Type == MessageTypeOpen { + msg.Args = data[1:] + return msg, nil + } + + if msg.Type == MessageTypeClose || msg.Type == MessageTypePing || + msg.Type == MessageTypePong || msg.Type == MessageTypeEmpty { + return msg, nil + } + + ack, rest, err := getAck(data) + msg.AckId = ack + if msg.Type == MessageTypeAckResponse { + if err != nil { + return nil, err + } + msg.Args = rest[1 : len(rest)-1] + return msg, nil + } + + if err != nil { + msg.Type = MessageTypeEmit + rest = data[2:] + } + + msg.Method, msg.Args, err = getMethod(rest) + if err != nil { + return nil, err + } + + return msg, nil +} diff --git a/send.go b/send.go new file mode 100644 index 0000000..663e99d --- /dev/null +++ b/send.go @@ -0,0 +1,77 @@ +package gosocketio + +import ( + "encoding/json" + "errors" + "time" + "funstream/libs/socket.io/protocol" +) + +var ( + ErrorSendTimeout = errors.New("Timeout") + ErrorSocketOverflood = errors.New("Socket overflood") +) + +/** +Send message packet to socket +*/ +func send(msg *protocol.Message, c *Channel, args interface{}) error { + json, err := json.Marshal(&args) + if err != nil { + return err + } + + msg.Args = string(json) + + command, err := protocol.Encode(msg) + if err != nil { + return err + } + + if len(c.out) == queueBufferSize { + return ErrorSocketOverflood + } + + c.out <- command + + return nil +} + +/** +Create packet based on given data and send it +*/ +func (c *Channel) Emit(method string, args interface{}) error { + msg := &protocol.Message{ + Type: protocol.MessageTypeEmit, + Method: method, + } + + return send(msg, c, args) +} + +/** +Create ack packet based on given data and send it and receive response +*/ +func (c *Channel) Ack(method string, args interface{}, timeout time.Duration) (string, error) { + msg := &protocol.Message{ + Type: protocol.MessageTypeAckRequest, + AckId: c.ack.getNextId(), + Method: method, + } + + waiter := make(chan string) + c.ack.addWaiter(msg.AckId, waiter) + + err := send(msg, c, args) + if err != nil { + c.ack.removeWaiter(msg.AckId) + } + + select { + case result := <-waiter: + return result, nil + case <-time.After(timeout): + c.ack.removeWaiter(msg.AckId) + return "", ErrorSendTimeout + } +} diff --git a/server.go b/server.go new file mode 100644 index 0000000..0097965 --- /dev/null +++ b/server.go @@ -0,0 +1,325 @@ +package gosocketio + +import ( + "bytes" + "crypto/md5" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "funstream/libs/socket.io/transport" + "math/rand" + "net/http" + "sync" + "time" + "funstream/libs/socket.io/protocol" +) + +var ( + ErrorServerNotSet = errors.New("Server not set") + ErrorConnectionNotFound = errors.New("Connection not found") +) + +/** +socket.io server instance +*/ +type Server struct { + methods + http.Handler + + channels map[string]map[*Channel]struct{} + rooms map[*Channel]map[string]struct{} + channelsLock sync.RWMutex + + sids map[string]*Channel + sidsLock sync.RWMutex + + tr transport.Transport +} + +/** +Get channel by it's sid +*/ +func (s *Server) GetChannel(sid string) (*Channel, error) { + s.sidsLock.RLock() + defer s.sidsLock.RUnlock() + + c, ok := s.sids[sid] + if !ok { + return nil, ErrorConnectionNotFound + } + + return c, nil +} + +/** +Join this channel to given room +*/ +func (c *Channel) Join(room string) error { + if c.server == nil { + return ErrorServerNotSet + } + + c.server.channelsLock.Lock() + defer c.server.channelsLock.Unlock() + + cn := c.server.channels + if _, ok := cn[room]; !ok { + cn[room] = make(map[*Channel]struct{}) + } + + byRoom := c.server.rooms + if _, ok := byRoom[c]; !ok { + byRoom[c] = make(map[string]struct{}) + } + + cn[room][c] = struct{}{} + byRoom[c][room] = struct{}{} + + return nil +} + +/** +Remove this channel from given room +*/ +func (c *Channel) Leave(room string) error { + if c.server == nil { + return ErrorServerNotSet + } + + c.server.channelsLock.Lock() + defer c.server.channelsLock.Unlock() + + cn := c.server.channels + if _, ok := cn[room]; ok { + delete(cn[room], c) + if len(cn[room]) == 0 { + delete(cn, room) + } + } + + byRoom := c.server.rooms + if _, ok := byRoom[c]; ok { + delete(byRoom[c], room) + } + + return nil +} + +/** +Get list of channels, joined to given room, using channel +*/ +func (c *Channel) List(room string) ([]*Channel, error) { + if c.server == nil { + return nil, ErrorServerNotSet + } + + return c.server.List(room) +} + +/** +Get list of channels, joined to given room, using server +*/ +func (s *Server) List(room string) ([]*Channel, error) { + s.channelsLock.RLock() + defer s.channelsLock.RUnlock() + + roomChannels, ok := s.channels[room] + if !ok { + return []*Channel{}, nil + } + + i := 0 + roomChannelsCopy := make([]*Channel, len(roomChannels)) + for channel := range roomChannels { + roomChannelsCopy[i] = channel + i++ + } + + return roomChannelsCopy, nil + +} + +func (c *Channel) BroadcastTo(room, method string, args interface{}) { + if c.server == nil { + return + } + c.server.BroadcastTo(room, method, args) +} + +/** +Broadcast message to all room channels +*/ +func (s *Server) BroadcastTo(room, method string, args interface{}) { + s.channelsLock.RLock() + defer s.channelsLock.RUnlock() + + roomChannels, ok := s.channels[room] + if !ok { + return + } + + for cn := range roomChannels { + if cn.IsAlive() { + go cn.Emit(method, args) + } + } +} + +/** +Broadcast to all clients +*/ +func (s *Server) BroadcastToAll(method string, args interface{}) { + s.sidsLock.RLock() + defer s.sidsLock.RUnlock() + + for _, cn := range s.sids { + if cn.IsAlive() { + go cn.Emit(method, args) + } + } +} + +/** +Generate new id for socket.io connection +*/ +func generateNewId(custom string) string { + hash := fmt.Sprintf("%s %s %n %n", custom, time.Now(), rand.Uint32(), rand.Uint32()) + buf := bytes.NewBuffer(nil) + sum := md5.Sum([]byte(hash)) + encoder := base64.NewEncoder(base64.URLEncoding, buf) + encoder.Write(sum[:]) + encoder.Close() + return buf.String()[:20] +} + +/** +On connection system handler, store sid +*/ +func onConnectStore(c *Channel) { + c.server.sidsLock.Lock() + defer c.server.sidsLock.Unlock() + + c.server.sids[c.Id()] = c +} + +/** +On disconnection system handler, clean joins and sid +*/ +func onDisconnectCleanup(c *Channel) { + c.server.channelsLock.Lock() + defer c.server.channelsLock.Unlock() + + cn := c.server.channels + byRoom, ok := c.server.rooms[c] + if ok { + for room := range byRoom { + if curRoom, ok := cn[room]; ok { + delete(curRoom, c) + if len(curRoom) == 0 { + delete(cn, room) + } + } + } + + delete(c.server.rooms, c) + } + + c.server.sidsLock.Lock() + defer c.server.sidsLock.Unlock() + + delete(c.server.sids, c.Id()) +} + +func (s *Server) SendOpenSequence(c *Channel) { + jsonHdr, err := json.Marshal(&c.header) + if err != nil { + panic(err) + } + + c.out <- protocol.MustEncode( + &protocol.Message{ + Type: protocol.MessageTypeOpen, + Args: string(jsonHdr), + }, + ) + + c.out <- protocol.MustEncode(&protocol.Message{Type: protocol.MessageTypeEmpty}) +} + +/** +Setup event loop for given connection +*/ +func (s *Server) SetupEventLoop(conn transport.Connection, remoteAddr string) { + interval, timeout := conn.PingParams() + hdr := Header{ + Sid: generateNewId(remoteAddr), + Upgrades: []string{}, + PingInterval: int(interval / time.Millisecond), + PingTimeout: int(timeout / time.Millisecond), + } + + c := &Channel{} + c.conn = conn + c.ip = remoteAddr + c.initChannel() + + c.server = s + c.header = hdr + + s.SendOpenSequence(c) + + go inLoop(c, &s.methods) + go outLoop(c, &s.methods) + + s.callLoopEvent(c, OnConnection) +} + +/** +implements ServeHTTP function from http.Handler +*/ +func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { + conn, err := s.tr.HandleConnection(w, r) + if err != nil { + return + } + + s.SetupEventLoop(conn, r.RemoteAddr) + s.tr.Serve(w, r) +} + +/** +Get amount of current connected sids +*/ +func (s *Server) AmountOfSids() int64 { + s.sidsLock.RLock() + defer s.sidsLock.RUnlock() + + return int64(len(s.sids)) +} + +/** +Get amount of rooms with at least one channel(or sid) joined +*/ +func (s *Server) AmountOfRooms() int64 { + s.channelsLock.RLock() + defer s.channelsLock.RUnlock() + + return int64(len(s.channels)) +} + +/** +Create new socket.io server +*/ +func NewServer(tr transport.Transport) *Server { + s := Server{} + s.initMethods() + s.tr = tr + s.channels = make(map[string]map[*Channel]struct{}) + s.rooms = make(map[*Channel]map[string]struct{}) + s.sids = make(map[string]*Channel) + s.onConnection = onConnectStore + s.onDisconnection = onDisconnectCleanup + + return &s +} diff --git a/transport/transport.go b/transport/transport.go new file mode 100644 index 0000000..fe7a4a3 --- /dev/null +++ b/transport/transport.go @@ -0,0 +1,51 @@ +package transport + +import ( + "net/http" + "time" +) + +/** +End-point connection for given transport +*/ +type Connection interface { + /** + Receive one more message, block until received + */ + GetMessage() (message string, err error) + + /** + Send given message, block until sent + */ + WriteMessage(message string) error + + /** + Close current connection + */ + Close() + + /** + Get ping time interval and ping request timeout + */ + PingParams() (interval, timeout time.Duration) +} + +/** +Connection factory for given transport +*/ +type Transport interface { + /** + Get client connection + */ + Connect(host string) (conn Connection, err error) + + /** + Handle one server connection + */ + HandleConnection(w http.ResponseWriter, r *http.Request) (conn Connection, err error) + + /** + Serve HTTP request after making connection and events setup + */ + Serve(w http.ResponseWriter, r *http.Request) +} diff --git a/transport/websocket.go b/transport/websocket.go new file mode 100644 index 0000000..9480777 --- /dev/null +++ b/transport/websocket.go @@ -0,0 +1,138 @@ +package transport + +import ( + "errors" + "github.com/gorilla/websocket" + "io/ioutil" + "net/http" + "time" +) + +const ( + webSocketProtocol = "ws://" + socketioUrl = "/socket.io/?EIO=3&transport=websocket" + upgradeFailed = "Upgrade failed: " + + WsDefaultPingInterval = 30 * time.Second + WsDefaultPingTimeout = 60 * time.Second + WsDefaultReceiveTimeout = 60 * time.Second + WsDefaultSendTimeout = 60 * time.Second + WsDefaultBufferSize = 1024 * 32 +) + +var ( + ErrorBinaryMessage = errors.New("Binary messages are not supported") + ErrorBadBuffer = errors.New("Buffer error") + ErrorPacketWrong = errors.New("Wrong packet type error") + ErrorMethodNotAllowed = errors.New("Method not allowed") + ErrorHttpUpgradeFailed = errors.New("Http upgrade failed") +) + +type WebsocketConnection struct { + socket *websocket.Conn + transport *WebsocketTransport +} + +func (wsc *WebsocketConnection) GetMessage() (message string, err error) { + wsc.socket.SetReadDeadline(time.Now().Add(wsc.transport.ReceiveTimeout)) + msgType, reader, err := wsc.socket.NextReader() + if err != nil { + return "", err + } + + //support only text messages exchange + if msgType != websocket.TextMessage { + return "", ErrorBinaryMessage + } + + data, err := ioutil.ReadAll(reader) + if err != nil { + return "", ErrorBadBuffer + } + text := string(data) + + //empty messages are not allowed + if len(text) == 0 { + return "", ErrorPacketWrong + } + + return text, nil +} + +func (wsc *WebsocketConnection) WriteMessage(message string) error { + wsc.socket.SetWriteDeadline(time.Now().Add(wsc.transport.SendTimeout)) + writer, err := wsc.socket.NextWriter(websocket.TextMessage) + if err != nil { + return err + } + + if _, err := writer.Write([]byte(message)); err != nil { + return err + } + if err := writer.Close(); err != nil { + return err + } + return nil +} + +func (wsc *WebsocketConnection) Close() { + wsc.socket.Close() +} + +func (wsc *WebsocketConnection) PingParams() (interval, timeout time.Duration) { + return wsc.transport.PingInterval, wsc.transport.PingTimeout +} + +type WebsocketTransport struct { + PingInterval time.Duration + PingTimeout time.Duration + ReceiveTimeout time.Duration + SendTimeout time.Duration + + BufferSize int +} + +func (wst *WebsocketTransport) Connect(host string) (conn Connection, err error) { + dialer := websocket.Dialer{} + socket, _, err := dialer.Dial(webSocketProtocol+host+socketioUrl, nil) + if err != nil { + return nil, err + } + + return &WebsocketConnection{socket, wst}, nil +} + +func (wst *WebsocketTransport) HandleConnection( + w http.ResponseWriter, r *http.Request) (conn Connection, err error) { + + if r.Method != "GET" { + http.Error(w, upgradeFailed+ErrorMethodNotAllowed.Error(), 503) + return nil, ErrorMethodNotAllowed + } + + socket, err := websocket.Upgrade(w, r, nil, wst.BufferSize, wst.BufferSize) + if err != nil { + http.Error(w, upgradeFailed+err.Error(), 503) + return nil, ErrorHttpUpgradeFailed + } + + return &WebsocketConnection{socket, wst}, nil +} + +/** +Websocket connection do not require any additional processing +*/ +func (wst *WebsocketTransport) Serve(w http.ResponseWriter, r *http.Request) {} + +/** +Returns websocket connection with default params +*/ +func GetDefaultWebsocketTransport() *WebsocketTransport { + return &WebsocketTransport{ + PingInterval: WsDefaultPingInterval, + PingTimeout: WsDefaultPingTimeout, + ReceiveTimeout: WsDefaultReceiveTimeout, + SendTimeout: WsDefaultSendTimeout, + BufferSize: WsDefaultBufferSize, + } +}