Skip to content
Snippets Groups Projects
ws.go 2.69 KiB
Newer Older
  • Learn to ignore specific revisions
  • package main
    
    import (
    	"net/http"
    	"sync"
    	"time"
    
    	"github.com/hashicorp/go-multierror"
    	"github.com/lesismal/nbio/nbhttp/websocket"
    	"github.com/sirupsen/logrus"
    )
    
    // This implementation follows the code example at:
    // https://github.com/lesismal/nbio/issues/92#issuecomment-922183823
    
    var (
    	// clientManager is the package-level registry that the upgrader
    	// callbacks add connections to and remove them from.
    	clientManager = NewClientManager()
    
    	// keepaliveTimer is the read-deadline window: a connection's read
    	// deadline is set to now+keepaliveTimer when it is opened and again on
    	// every message it sends.
    	keepaliveTimer = 60 * time.Second
    )
    
    // ClientManager keeps track of the currently active websocket connections,
    // referred to as "clients". Every access to the client set goes through mux.
    type ClientManager struct {
    	mux     sync.Mutex                   // guards clients
    	clients map[*websocket.Conn]struct{} // registered connections, used as a set
    }
    
    // Register starts tracking client so that subsequent Publish calls include
    // it, then logs the client's remote address. The insertion and the log line
    // both happen while the manager's mutex is held.
    func (cMngr *ClientManager) Register(client *websocket.Conn) {
    	lock := &cMngr.mux
    	lock.Lock()
    	defer lock.Unlock()
    
    	registry := cMngr.clients
    	registry[client] = struct{}{}
    
    	addr := client.RemoteAddr()
    	logrus.Println("Added new client: ", addr)
    }
    
    // Deregister stops tracking client so that subsequent Publish calls skip it,
    // then logs the client's remote address. Removing a client that was never
    // registered changes nothing except that the log line is still written.
    // Both steps happen while the manager's mutex is held.
    func (cMngr *ClientManager) Deregister(client *websocket.Conn) {
    	lock := &cMngr.mux
    	lock.Lock()
    	defer lock.Unlock()
    
    	registry := cMngr.clients
    	delete(registry, client)
    
    	addr := client.RemoteAddr()
    	logrus.Println("Removed client: ", addr)
    }
    
    // Publish sends message as a text frame to every registered client
    // concurrently and waits until all writes have returned. A failed write to
    // one client does not prevent delivery to the others; all write errors are
    // collected and logged together once every write has finished.
    func (cMngr *ClientManager) Publish(message []byte) {
    	// Snapshot the client set and release the mutex before doing any network
    	// I/O. Holding it across the writes would block Register/Deregister (i.e.
    	// new connections and closes) for as long as the slowest client takes.
    	cMngr.mux.Lock()
    	targets := make([]*websocket.Conn, 0, len(cMngr.clients))
    	for c := range cMngr.clients {
    		targets = append(targets, c)
    	}
    	cMngr.mux.Unlock()
    
    	var eg multierror.Group
    	for _, c := range targets {
    		// Give each goroutine its own copy of the loop variable; before Go 1.22
    		// all closures would otherwise share one variable and write to
    		// whichever client it last referred to.
    		c := c
    		eg.Go(func() error {
    			return c.WriteMessage(websocket.TextMessage, message)
    		})
    	}
    	if err := eg.Wait(); err != nil {
    		logrus.Printf("Publish encountered errors while broadcasting: %v\n", err)
    	}
    }
    
    // NewClientManager allocates a ClientManager whose client set is empty and
    // ready for use. The manager contains a sync.Mutex, so share the returned
    // pointer rather than copying the value.
    func NewClientManager() *ClientManager {
    	cm := new(ClientManager)
    	cm.clients = map[*websocket.Conn]struct{}{}
    	return cm
    }
    
    // newUpgrader returns a websocket.Upgrader wired to the package-level
    // clientManager:
    //   - OnOpen registers the new connection;
    //   - OnClose deregisters it (the close error is ignored);
    //   - OnMessage ignores the message type and payload and only pushes the
    //     connection's read deadline keepaliveTimer into the future, logging if
    //     that fails.
    func newUpgrader() *websocket.Upgrader {
    	upgrader := websocket.NewUpgrader()
    
    	onOpen := func(conn *websocket.Conn) {
    		clientManager.Register(conn)
    	}
    	onClose := func(conn *websocket.Conn, _ error) {
    		clientManager.Deregister(conn)
    	}
    	onMessage := func(conn *websocket.Conn, _ websocket.MessageType, _ []byte) {
    		deadline := time.Now().Add(keepaliveTimer)
    		if err := conn.SetReadDeadline(deadline); err != nil {
    			logrus.Printf("Could not update ReadDeadline: %v\n", err)
    		}
    	}
    
    	upgrader.OnOpen(onOpen)
    	upgrader.OnClose(onClose)
    	upgrader.OnMessage(onMessage)
    
    	return upgrader
    }
    
    // onWebsocket is the HTTP handler that upgrades a request to a websocket
    // connection. Registration with clientManager is done by the OnOpen callback
    // installed by newUpgrader; this handler only arms the initial read deadline
    // (keepaliveTimer from now), which OnMessage later keeps pushing forward.
    //
    // If the upgrade fails, the error is logged and nothing else happens. If the
    // initial deadline cannot be set, the connection is closed, because it
    // would otherwise not be subject to the keepalive timeout at all.
    func onWebsocket(w http.ResponseWriter, r *http.Request) {
    	upgrader := newUpgrader()
    	conn, err := upgrader.Upgrade(w, r, nil)
    	if err != nil {
    		// Include the cause; without it a failed handshake is undiagnosable.
    		logrus.Printf("Could not create websocket: %v\n", err)
    		return
    	}
    
    	if err := conn.SetReadDeadline(time.Now().Add(keepaliveTimer)); err != nil {
    		logrus.Printf("Could not set initial ReadDeadline: %v\n", err)
    		// Drop the connection instead of leaving it open with no keepalive
    		// deadline. The close error is ignored: the connection is being
    		// discarded either way.
    		_ = conn.Close()
    	}
    }