diff --git a/lib/connections/service.go b/lib/connections/service.go index 01480defd..96caf949f 100644 --- a/lib/connections/service.go +++ b/lib/connections/service.go @@ -202,7 +202,6 @@ next: // Lower priority is better, just like nice etc. if priorityKnown && ct.Priority > c.Priority { l.Debugln("Switching connections", remoteID) - s.model.Close(remoteID, protocol.ErrSwitchingConnections) } else if connected { // We should not already be connected to the other party. TODO: This // could use some better handling. If the old connection is dead but diff --git a/lib/connections/structs.go b/lib/connections/structs.go index 65d2fee48..53c432afb 100644 --- a/lib/connections/structs.go +++ b/lib/connections/structs.go @@ -8,6 +8,7 @@ package connections import ( "crypto/tls" + "fmt" "net" "net/url" "time" @@ -28,6 +29,10 @@ type Connection struct { protocol.Connection } +func (c Connection) String() string { + return fmt.Sprintf("%s-%s/%s", c.LocalAddr(), c.RemoteAddr(), c.Type) +} + type dialerFactory interface { New(*config.Wrapper, *tls.Config) genericDialer Priority() int diff --git a/lib/model/model.go b/lib/model/model.go index b887b8099..fbaf18a8f 100644 --- a/lib/model/model.go +++ b/lib/model/model.go @@ -94,6 +94,7 @@ type Model struct { fmut sync.RWMutex // protects the above conn map[protocol.DeviceID]connections.Connection + closed map[protocol.DeviceID]chan struct{} helloMessages map[protocol.DeviceID]protocol.HelloResult devicePaused map[protocol.DeviceID]bool deviceDownloads map[protocol.DeviceID]*deviceDownloadState @@ -152,6 +153,7 @@ func NewModel(cfg *config.Wrapper, id protocol.DeviceID, deviceName, clientName, folderRunnerTokens: make(map[string][]suture.ServiceToken), folderStatRefs: make(map[string]*stats.FolderStatisticsReference), conn: make(map[protocol.DeviceID]connections.Connection), + closed: make(map[protocol.DeviceID]chan struct{}), helloMessages: make(map[protocol.DeviceID]protocol.HelloResult), devicePaused: make(map[protocol.DeviceID]bool), deviceDownloads: make(map[protocol.DeviceID]*deviceDownloadState), @@ -912,25 +914,42 @@ func (m *Model) ClusterConfig(deviceID protocol.DeviceID, cm protocol.ClusterCon } } -// Close removes the peer from the model and closes the underlying connection if possible. -// Implements the protocol.Model interface. -func (m *Model) Close(device protocol.DeviceID, err error) { - l.Infof("Connection to %s closed: %v", device, err) - events.Default.Log(events.DeviceDisconnected, map[string]string{ - "id": device.String(), - "error": err.Error(), - }) +// Closed is called when a connection has been closed +func (m *Model) Closed(conn protocol.Connection, err error) { + device := conn.ID() m.pmut.Lock() conn, ok := m.conn[device] if ok { m.progressEmitter.temporaryIndexUnsubscribe(conn) - closeRawConn(conn) } delete(m.conn, device) delete(m.helloMessages, device) delete(m.deviceDownloads, device) + closed := m.closed[device] + delete(m.closed, device) m.pmut.Unlock() + + l.Infof("Connection to %s closed: %v", device, err) + events.Default.Log(events.DeviceDisconnected, map[string]string{ + "id": device.String(), + "error": err.Error(), + }) + close(closed) +} + +// close will close the underlying connection for a given device +func (m *Model) close(device protocol.DeviceID) { + m.pmut.Lock() + conn, ok := m.conn[device] + m.pmut.Unlock() + + if !ok { + // There is no connection to close + return + } + + closeRawConn(conn) } // Request returns the specified data segment by reading it from local disk. @@ -1171,10 +1190,22 @@ func (m *Model) AddConnection(conn connections.Connection, hello protocol.HelloR deviceID := conn.ID() m.pmut.Lock() - if _, ok := m.conn[deviceID]; ok { - panic("add existing device") + if oldConn, ok := m.conn[deviceID]; ok { + l.Infoln("Replacing old connection", oldConn, "with", conn, "for", deviceID) + // There is an existing connection to this device that we are + // replacing. We must close the existing connection and wait for the + // close to complete before adding the new connection. We do the + // actual close without holding pmut as the connection will call + // back into Closed() for the cleanup. + closed := m.closed[deviceID] + m.pmut.Unlock() + closeRawConn(oldConn) + <-closed + m.pmut.Lock() } + m.conn[deviceID] = conn + m.closed[deviceID] = make(chan struct{}) m.deviceDownloads[deviceID] = newDeviceDownloadState() m.helloMessages[deviceID] = hello @@ -1215,10 +1246,10 @@ func (m *Model) AddConnection(conn connections.Connection, hello protocol.HelloR func (m *Model) PauseDevice(device protocol.DeviceID) { m.pmut.Lock() m.devicePaused[device] = true - _, ok := m.conn[device] + conn, ok := m.conn[device] m.pmut.Unlock() if ok { - m.Close(device, errors.New("device paused")) + closeRawConn(conn) } events.Default.Log(events.DevicePaused, map[string]string{"device": device.String()}) } diff --git a/lib/model/model_test.go b/lib/model/model_test.go index 793162785..c6c89643b 100644 --- a/lib/model/model_test.go +++ b/lib/model/model_test.go @@ -351,7 +351,7 @@ func TestDeviceRename(t *testing.T) { t.Errorf("Device already has a name") } - m.Close(device1, protocol.ErrTimeout) + m.Closed(conn, protocol.ErrTimeout) hello.DeviceName = "tester" m.AddConnection(conn, hello) @@ -359,7 +359,7 @@ func TestDeviceRename(t *testing.T) { t.Errorf("Device did not get a name") } - m.Close(device1, protocol.ErrTimeout) + m.Closed(conn, protocol.ErrTimeout) hello.DeviceName = "tester2" m.AddConnection(conn, hello) @@ -376,7 +376,7 @@ func TestDeviceRename(t *testing.T) { t.Errorf("Device name not saved in config") } - m.Close(device1, protocol.ErrTimeout) + m.Closed(conn, protocol.ErrTimeout) opts := cfg.Options() opts.OverwriteRemoteDevNames = true @@ -1527,7 +1527,7 @@ func TestSharedWithClearedOnDisconnect(t *testing.T) { m.StartFolder(fcfg.ID) m.ServeBackground() - m.AddConnection(connections.Connection{ + conn1 := connections.Connection{ IntermediateConnection: connections.IntermediateConnection{ Conn: tls.Client(&fakeConn{}, nil), Type: "foo", @@ -1536,8 +1536,9 @@ func TestSharedWithClearedOnDisconnect(t *testing.T) { Connection: &FakeConnection{ id: device1, }, - }, protocol.HelloResult{}) - m.AddConnection(connections.Connection{ + } + m.AddConnection(conn1, protocol.HelloResult{}) + conn2 := connections.Connection{ IntermediateConnection: connections.IntermediateConnection{ Conn: tls.Client(d2c, nil), Type: "foo", @@ -1546,7 +1547,8 @@ func TestSharedWithClearedOnDisconnect(t *testing.T) { Connection: &FakeConnection{ id: device2, }, - }, protocol.HelloResult{}) + } + m.AddConnection(conn2, protocol.HelloResult{}) m.ClusterConfig(device1, protocol.ClusterConfig{ Folders: []protocol.Folder{ @@ -1629,7 +1631,7 @@ func TestSharedWithClearedOnDisconnect(t *testing.T) { t.Error("downloads missing early") } - m.Close(device2, fmt.Errorf("foo")) + m.Closed(conn2, fmt.Errorf("foo")) if _, ok := m.conn[device2]; ok { t.Error("conn not missing") diff --git a/lib/protocol/benchmark_test.go b/lib/protocol/benchmark_test.go index 7a0049ada..c10e97d2d 100644 --- a/lib/protocol/benchmark_test.go +++ b/lib/protocol/benchmark_test.go @@ -181,7 +181,7 @@ func (m *fakeModel) Request(deviceID DeviceID, folder string, name string, offse func (m *fakeModel) ClusterConfig(deviceID DeviceID, config ClusterConfig) { } -func (m *fakeModel) Close(deviceID DeviceID, err error) { +func (m *fakeModel) Closed(conn Connection, err error) { } func (m *fakeModel) DownloadProgress(deviceID DeviceID, folder string, updates []FileDownloadProgressUpdate) { diff --git a/lib/protocol/common_test.go b/lib/protocol/common_test.go index 75356be55..6fdee7a1b 100644 --- a/lib/protocol/common_test.go +++ b/lib/protocol/common_test.go @@ -39,7 +39,7 @@ func (t *TestModel) Request(deviceID DeviceID, folder, name string, offset int64 return nil } -func (t *TestModel) Close(deviceID DeviceID, err error) { +func (t *TestModel) Closed(conn Connection, err error) { t.closedErr = err close(t.closedCh) } diff --git a/lib/protocol/protocol.go b/lib/protocol/protocol.go index 96493e8ef..b147a1a2d 100644 --- a/lib/protocol/protocol.go +++ b/lib/protocol/protocol.go @@ -67,7 +67,7 @@ type Model interface { // A cluster configuration message was received ClusterConfig(deviceID DeviceID, config ClusterConfig) // The peer device closed the connection - Close(deviceID DeviceID, err error) + Closed(conn Connection, err error) // The peer device sent progress updates for the files it is currently downloading DownloadProgress(deviceID DeviceID, folder string, updates []FileDownloadProgressUpdate) } @@ -729,7 +729,7 @@ func (c *rawConnection) close(err error) { } c.awaitingMut.Unlock() - go c.receiver.Close(c.id, err) + c.receiver.Closed(c, err) }) }