Skip to content

Commit

Permalink
improve session close synchronization (northwesternmutual#18)
Browse files Browse the repository at this point in the history
* improve session close synchronization

* CR
  • Loading branch information
AdallomRoy committed Jul 21, 2021
1 parent e1604d9 commit 87ea36d
Show file tree
Hide file tree
Showing 6 changed files with 31 additions and 14 deletions.
2 changes: 2 additions & 0 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
package grammes

import (
"golang.org/x/sync/errgroup"
"sync"
"time"

Expand Down Expand Up @@ -72,6 +73,7 @@ type Client struct {
// requestTimeout is used for time-outing requests that a response is not received for
requestTimeout time.Duration
requestSemaphore *semaphore.Weighted
commRoutines errgroup.Group
}

// setupClient default values some fields in the client.
Expand Down
18 changes: 14 additions & 4 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ package grammes

import (
"errors"

"github.com/northwesternmutual/grammes/gremconnect"
"github.com/northwesternmutual/grammes/gremerror"
)
Expand All @@ -43,9 +42,18 @@ func (c *Client) launchConnection() error {
quit := c.conn.GetQuit()

// Launch processes to keep track of connection & data
go c.writeWorker(c.err, quit) // Initiates message writing to the Gremlin-server
go c.readWorker(c.err, quit) // Initiates message reading from the Gremlin-server
go c.conn.Ping(c.err) // Manages pinging and connection to the Gremlin-server
c.commRoutines.Go(func() error {
c.writeWorker(c.err, quit) // Initiates message writing to the Gremlin-server
return nil
})
c.commRoutines.Go(func() error {
c.readWorker(c.err, quit) // Initiates message reading from the Gremlin-server
return nil
})
c.commRoutines.Go(func() error {
c.conn.Ping(c.err) // Manages pinging and connection to the Gremlin-server
return nil
})

return nil
}
Expand All @@ -55,6 +63,8 @@ func (c *Client) Close() {
if c.conn != nil {
c.conn.Close()
}

c.commRoutines.Wait()
}

// IsConnected returns if the client currently
Expand Down
2 changes: 1 addition & 1 deletion gremconnect/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,8 +145,8 @@ func (ws *WebSocket) Read() (msg []byte, err error) {
func (ws *WebSocket) Close() error {
defer func() {
close(ws.Quit) // close the channel to notify our pinger.
ws.conn.Close()
ws.disposed = true
ws.conn.Close()
}()

// Send the server the message that we've closed
Expand Down
4 changes: 2 additions & 2 deletions manager/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -182,8 +182,8 @@ type SessionQuerier interface {
NewSession() Session
NewNoopSession() Session
GetSession(uuid.UUID) Session
WithSession(Session, func(Session) error) error
WithNewSession(f func(Session) error) error
WithSession(Session, func(Session) error, bool) error
WithNewSession(func(Session) error, bool) error
}

// VertexQuerier handles the vertices on the graph.
Expand Down
14 changes: 8 additions & 6 deletions manager/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,16 +49,18 @@ func (s *sessionManager) GetSession(sessionId uuid.UUID) Session {
}
}

func (s *sessionManager) WithNewSession(f func(Session) error) error {
return s.WithSession(s.NewSession(), f)
func (s *sessionManager) WithNewSession(f func(Session) error, closeOnDone bool) error {
return s.WithSession(s.NewSession(), f, closeOnDone)
}

func (s *sessionManager) WithSession(ss Session, f func(Session) error) error {
func (s *sessionManager) WithSession(ss Session, f func(Session) error, closeOnDone bool) error {
var err error
defer func() {
err2 := ss.Close()
if err == nil { // Capture close error if everything else was ok
err = err2
if closeOnDone {
err2 := ss.Close()
if err == nil { // Capture close error if everything else was ok
err = err2
}
}
}()

Expand Down
5 changes: 4 additions & 1 deletion response.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,10 @@ func (c *Client) readWorker(errs chan error, quit chan struct{}) {
// attempt to read from the connection
// and store the message back into a variable.
if msg, err = c.conn.Read(); err != nil {
errs <- err
if !c.conn.IsDisposed() { // When disposing a connection, gorilla will return an error
errs <- err
}

c.broken = true
break
}
Expand Down

0 comments on commit 87ea36d

Please sign in to comment.