diff --git a/lib/protocol/protocol.go b/lib/protocol/protocol.go index 1ed484ecc..4b4771cb6 100644 --- a/lib/protocol/protocol.go +++ b/lib/protocol/protocol.go @@ -170,6 +170,8 @@ type rawConnection struct { closeOnce sync.Once sendCloseOnce sync.Once compression Compression + + loopWG sync.WaitGroup // Need to ensure no leftover routines in testing } type asyncResult struct { @@ -244,20 +246,35 @@ func newRawConnection(deviceID DeviceID, reader io.Reader, writer io.Writer, rec dispatcherLoopStopped: make(chan struct{}), closed: make(chan struct{}), compression: compress, + loopWG: sync.WaitGroup{}, } } // Start creates the goroutines for sending and receiving of messages. It must // be called exactly once after creating a connection. func (c *rawConnection) Start() { - go c.readerLoop() + c.loopWG.Add(5) + go func() { + c.readerLoop() + c.loopWG.Done() + }() go func() { err := c.dispatcherLoop() - c.internalClose(err) + c.Close(err) + c.loopWG.Done() + }() + go func() { + c.writerLoop() + c.loopWG.Done() + }() + go func() { + c.pingSender() + c.loopWG.Done() + }() + go func() { + c.pingReceiver() + c.loopWG.Done() }() - go c.writerLoop() - go c.pingSender() - go c.pingReceiver() c.startTime = time.Now() } @@ -410,7 +427,7 @@ func (c *rawConnection) dispatcherLoop() (err error) { state = stateReady } if err := c.receiver.ClusterConfig(c.id, *msg); err != nil { - return errors.Wrap(err, "receiver error") + return fmt.Errorf("receiving cluster config: %w", err) } case *Index: @@ -422,7 +439,7 @@ func (c *rawConnection) dispatcherLoop() (err error) { return errors.Wrap(err, "protocol error: index") } if err := c.handleIndex(*msg); err != nil { - return errors.Wrap(err, "receiver error") + return fmt.Errorf("receiving index: %w", err) } state = stateReady @@ -435,7 +452,7 @@ func (c *rawConnection) dispatcherLoop() (err error) { return errors.Wrap(err, "protocol error: index update") } if err := c.handleIndexUpdate(*msg); err != nil { - return errors.Wrap(err, "receiver error") + return fmt.Errorf("receiving index update: %w", err) } state = stateReady @@ -462,7 +479,7 @@ func (c *rawConnection) dispatcherLoop() (err error) { return fmt.Errorf("protocol error: response message in state %d", state) } if err := c.receiver.DownloadProgress(c.id, msg.Folder, msg.Updates); err != nil { - return errors.Wrap(err, "receiver error") + return fmt.Errorf("receiving download progress: %w", err) } case *Ping: @@ -474,7 +491,7 @@ func (c *rawConnection) dispatcherLoop() (err error) { case *Close: l.Debugln("read Close message") - return errors.New(msg.Reason) + return fmt.Errorf("closed by remote: %v", msg.Reason) default: l.Debugf("read unknown message: %+T", msg) diff --git a/lib/protocol/protocol_test.go b/lib/protocol/protocol_test.go index 398ac8987..ec56dbcfc 100644 --- a/lib/protocol/protocol_test.go +++ b/lib/protocol/protocol_test.go @@ -33,8 +33,10 @@ func TestPing(t *testing.T) { c0 := NewConnection(c0ID, ar, bw, newTestModel(), "name", CompressionAlways).(wireFormatConnection).Connection.(*rawConnection) c0.Start() + defer closeAndWait(c0, ar, bw) c1 := NewConnection(c1ID, br, aw, newTestModel(), "name", CompressionAlways).(wireFormatConnection).Connection.(*rawConnection) c1.Start() + defer closeAndWait(c1, ar, bw) c0.ClusterConfig(ClusterConfig{}) c1.ClusterConfig(ClusterConfig{}) @@ -57,8 +59,10 @@ func TestClose(t *testing.T) { c0 := NewConnection(c0ID, ar, bw, m0, "name", CompressionAlways).(wireFormatConnection).Connection.(*rawConnection) c0.Start() + defer closeAndWait(c0, ar, bw) c1 := NewConnection(c1ID, br, aw, m1, "name", CompressionAlways) c1.Start() + defer closeAndWait(c1, ar, bw) c0.ClusterConfig(ClusterConfig{}) c1.ClusterConfig(ClusterConfig{}) @@ -97,8 +101,10 @@ func TestCloseOnBlockingSend(t *testing.T) { m := newTestModel() - c := NewConnection(c0ID, &testutils.BlockingRW{}, &testutils.BlockingRW{}, m, "name", CompressionAlways).(wireFormatConnection).Connection.(*rawConnection) + rw := testutils.NewBlockingRW() + c := NewConnection(c0ID, rw, rw, m, "name", CompressionAlways).(wireFormatConnection).Connection.(*rawConnection) c.Start() + defer closeAndWait(c, rw) wg := sync.WaitGroup{} @@ -149,8 +155,10 @@ func TestCloseRace(t *testing.T) { c0 := NewConnection(c0ID, ar, bw, m0, "c0", CompressionNever).(wireFormatConnection).Connection.(*rawConnection) c0.Start() + defer closeAndWait(c0, ar, bw) c1 := NewConnection(c1ID, br, aw, m1, "c1", CompressionNever) c1.Start() + defer closeAndWait(c1, ar, bw) c0.ClusterConfig(ClusterConfig{}) c1.ClusterConfig(ClusterConfig{}) @@ -184,8 +192,10 @@ func TestCloseRace(t *testing.T) { func TestClusterConfigFirst(t *testing.T) { m := newTestModel() - c := NewConnection(c0ID, &testutils.BlockingRW{}, &testutils.NoopRW{}, m, "name", CompressionAlways).(wireFormatConnection).Connection.(*rawConnection) + rw := testutils.NewBlockingRW() + c := NewConnection(c0ID, rw, &testutils.NoopRW{}, m, "name", CompressionAlways).(wireFormatConnection).Connection.(*rawConnection) c.Start() + defer closeAndWait(c, rw) select { case c.outbox <- asyncMessage{&Ping{}, nil}: @@ -234,8 +244,10 @@ func TestCloseTimeout(t *testing.T) { m := newTestModel() - c := NewConnection(c0ID, &testutils.BlockingRW{}, &testutils.BlockingRW{}, m, "name", CompressionAlways).(wireFormatConnection).Connection.(*rawConnection) + rw := testutils.NewBlockingRW() + c := NewConnection(c0ID, rw, rw, m, "name", CompressionAlways).(wireFormatConnection).Connection.(*rawConnection) c.Start() + defer closeAndWait(c, rw) done := make(chan struct{}) go func() { @@ -852,8 +864,10 @@ func TestSha256OfEmptyBlock(t *testing.T) { func TestClusterConfigAfterClose(t *testing.T) { m := newTestModel() - c := NewConnection(c0ID, &testutils.BlockingRW{}, &testutils.BlockingRW{}, m, "name", CompressionAlways).(wireFormatConnection).Connection.(*rawConnection) + rw := testutils.NewBlockingRW() + c := NewConnection(c0ID, rw, rw, m, "name", CompressionAlways).(wireFormatConnection).Connection.(*rawConnection) c.Start() + defer closeAndWait(c, rw) c.internalClose(errManual) @@ -874,11 +888,13 @@ func TestDispatcherToCloseDeadlock(t *testing.T) { // Verify that we don't deadlock when calling Close() from within one of // the model callbacks (ClusterConfig). m := newTestModel() - c := NewConnection(c0ID, &testutils.BlockingRW{}, &testutils.NoopRW{}, m, "name", CompressionAlways).(wireFormatConnection).Connection.(*rawConnection) + rw := testutils.NewBlockingRW() + c := NewConnection(c0ID, rw, &testutils.NoopRW{}, m, "name", CompressionAlways).(wireFormatConnection).Connection.(*rawConnection) m.ccFn = func(devID DeviceID, cc ClusterConfig) { c.Close(errManual) } c.Start() + defer closeAndWait(c, rw) c.inbox <- &ClusterConfig{} @@ -945,3 +961,18 @@ func TestIndexIDString(t *testing.T) { t.Error(i.String()) } } + +func closeAndWait(c Connection, closers ...io.Closer) { + for _, closer := range closers { + closer.Close() + } + var raw *rawConnection + switch i := c.(type) { + case wireFormatConnection: + raw = i.Connection.(*rawConnection) + case *rawConnection: + raw = i + } + raw.internalClose(ErrClosed) + raw.loopWG.Wait() +} diff --git a/lib/testutils/testutils.go b/lib/testutils/testutils.go index cc5b1f974..77f420644 100644 --- a/lib/testutils/testutils.go +++ b/lib/testutils/testutils.go @@ -6,17 +6,40 @@ package testutils -// BlockingRW implements io.Reader and Writer but never returns when called -type BlockingRW struct{ nilChan chan struct{} } +import ( + "errors" + "sync" +) -func (rw *BlockingRW) Read(p []byte) (n int, err error) { - <-rw.nilChan - return +var ErrClosed = errors.New("closed") + +// BlockingRW implements io.Reader, Writer and Closer, but only returns when closed +type BlockingRW struct { + c chan struct{} + closeOnce sync.Once } -func (rw *BlockingRW) Write(p []byte) (n int, err error) { - <-rw.nilChan - return +func NewBlockingRW() *BlockingRW { + return &BlockingRW{ + c: make(chan struct{}), + closeOnce: sync.Once{}, + } +} +func (rw *BlockingRW) Read(p []byte) (int, error) { + <-rw.c + return 0, ErrClosed +} + +func (rw *BlockingRW) Write(p []byte) (int, error) { + <-rw.c + return 0, ErrClosed +} + +func (rw *BlockingRW) Close() error { + rw.closeOnce.Do(func() { + close(rw.c) + }) + return nil } // NoopRW implements io.Reader and Writer but never returns when called