Skip to content
Snippets Groups Projects
ws.go 1.94 KiB
Newer Older
  • Learn to ignore specific revisions
  • 	"github.com/hashicorp/go-multierror"
    
    	"github.com/lesismal/nbio/nbhttp/websocket"
    	"github.com/sirupsen/logrus"
    )
    
    
    // This implementation follows the code example at:
    // https://github.com/lesismal/nbio/issues/92#issuecomment-922183823
    
    
    var (
    	clientManager  = NewClientManager()
    
    Malte Bauch's avatar
    Malte Bauch committed
    	keepaliveTimer = time.Second * 60
    
    	clients map[*websocket.Conn]struct{}
    }
    
    // Register adds client to the set of connections that Publish broadcasts to,
    // then logs the client's remote address. The clients map is mutated under the
    // manager's lock. Assumes the manager came from NewClientManager (writing to a
    // nil clients map would panic).
    func (cMngr *ClientManager) Register(client *websocket.Conn) {
    	cMngr.mux.Lock()
    	defer cMngr.mux.Unlock()

    	var member struct{}
    	cMngr.clients[client] = member

    	addr := client.RemoteAddr()
    	logrus.Println("Added new client: ", addr)
    }
    // Deregister removes client from the broadcast set under the manager's lock
    // and logs its remote address. Removing a connection that was never
    // registered leaves the map unchanged, but the removal is still logged.
    func (cMngr *ClientManager) Deregister(client *websocket.Conn) {
    	cMngr.mux.Lock()
    	defer cMngr.mux.Unlock()

    	delete(cMngr.clients, client)

    	addr := client.RemoteAddr()
    	logrus.Println("Removed client: ", addr)
    }
    func (cMngr *ClientManager) Publish(message []byte) {
    
    	cMngr.mux.Lock()
    	defer cMngr.mux.Unlock()
    
    	var eg multierror.Group
    
    	for c := range cMngr.clients {
    		eg.Go(func() error {
    			return c.WriteMessage(websocket.TextMessage, message)
    		})
    
    	}
    	if err := eg.Wait(); err != nil {
    		logrus.Printf("Publish encountered errors while broadcasting: %v\n", err)
    
    // NewClientManager returns a ClientManager whose client set is an empty,
    // initialized map, so Register can add connections immediately.
    func NewClientManager() *ClientManager {
    	cm := new(ClientManager)
    	cm.clients = make(map[*websocket.Conn]struct{})
    	return cm
    }
    
    // newUpgrader builds a websocket upgrader wired to the package-level
    // clientManager. Opened connections are registered and closed ones
    // deregistered. Every received message pushes the connection's read deadline
    // keepaliveTimer into the future.
    func newUpgrader() *websocket.Upgrader {
    	upgrader := websocket.NewUpgrader()

    	onOpen := func(conn *websocket.Conn) {
    		clientManager.Register(conn)
    	}
    	onClose := func(conn *websocket.Conn, _ error) {
    		clientManager.Deregister(conn)
    	}
    	onMessage := func(conn *websocket.Conn, _ websocket.MessageType, _ []byte) {
    		// Message contents are ignored; receiving one only extends the deadline.
    		// NOTE(review): the SetReadDeadline error is not checked.
    		conn.SetReadDeadline(time.Now().Add(keepaliveTimer))
    	}

    	upgrader.OnOpen(onOpen)
    	upgrader.OnClose(onClose)
    	upgrader.OnMessage(onMessage)

    	return upgrader
    }
    
    // onWebsocket is an HTTP handler that upgrades the request to a websocket
    // connection. It uses a fresh upgrader per request, whose callbacks manage
    // registration with clientManager. It then sets an initial read deadline of
    // keepaliveTimer, which incoming messages keep extending.
    //
    // Failures are logged with their cause rather than silently discarded.
    func onWebsocket(w http.ResponseWriter, r *http.Request) {
    	upgrader := newUpgrader()
    	conn, err := upgrader.Upgrade(w, r, nil)
    	if err != nil {
    		logrus.Printf("Could not create websocket: %v\n", err)
    		return
    	}

    	// Without a deadline an idle client would never time out, so surface the
    	// failure instead of ignoring it.
    	if err := conn.SetReadDeadline(time.Now().Add(keepaliveTimer)); err != nil {
    		logrus.Printf("Could not set read deadline for websocket: %v\n", err)
    	}
    }