diff --git a/lib/protocol/protocol.go b/lib/protocol/protocol.go index af7a773e8..284df12d1 100644 --- a/lib/protocol/protocol.go +++ b/lib/protocol/protocol.go @@ -437,82 +437,61 @@ func (c *rawConnection) dispatcherLoop() (err error) { case <-c.closed: return ErrClosed } + + msgContext, err := messageContext(msg) + if err != nil { + return fmt.Errorf("protocol error: %w", err) + } + l.Debugf("handle %v message", msgContext) + switch msg := msg.(type) { case *ClusterConfig: - l.Debugln("read ClusterConfig message") if state == stateInitial { state = stateReady } - if err := c.receiver.ClusterConfig(c.id, *msg); err != nil { - return fmt.Errorf("receiving cluster config: %w", err) - } - - case *Index: - l.Debugln("read Index message") + case *Close: + return fmt.Errorf("closed by remote: %v", msg.Reason) + default: if state != stateReady { - return fmt.Errorf("protocol error: index message in state %d", state) + return newProtocolError(fmt.Errorf("invalid state %d", state), msgContext) } - if err := checkIndexConsistency(msg.Files); err != nil { - return errors.Wrap(err, "protocol error: index") - } - if err := c.handleIndex(*msg); err != nil { - return fmt.Errorf("receiving index: %w", err) - } - state = stateReady + } + + switch msg := msg.(type) { + case *Index: + err = checkIndexConsistency(msg.Files) case *IndexUpdate: - l.Debugln("read IndexUpdate message") - if state != stateReady { - return fmt.Errorf("protocol error: index update message in state %d", state) - } - if err := checkIndexConsistency(msg.Files); err != nil { - return errors.Wrap(err, "protocol error: index update") - } - if err := c.handleIndexUpdate(*msg); err != nil { - return fmt.Errorf("receiving index update: %w", err) - } - state = stateReady + err = checkIndexConsistency(msg.Files) + + case *Request: + err = checkFilename(msg.Name) + } + if err != nil { + return newProtocolError(err, msgContext) + } + + switch msg := msg.(type) { + case *ClusterConfig: + err = c.receiver.ClusterConfig(c.id, *msg) + + case *Index: + err = c.handleIndex(*msg) + + case *IndexUpdate: + err = c.handleIndexUpdate(*msg) case *Request: - l.Debugln("read Request message") - if state != stateReady { - return fmt.Errorf("protocol error: request message in state %d", state) - } - if err := checkFilename(msg.Name); err != nil { - return errors.Wrapf(err, "protocol error: request: %q", msg.Name) - } go c.handleRequest(*msg) case *Response: - l.Debugln("read Response message") - if state != stateReady { - return fmt.Errorf("protocol error: response message in state %d", state) - } c.handleResponse(*msg) case *DownloadProgress: - l.Debugln("read DownloadProgress message") - if state != stateReady { - return fmt.Errorf("protocol error: response message in state %d", state) - } - if err := c.receiver.DownloadProgress(c.id, msg.Folder, msg.Updates); err != nil { - return fmt.Errorf("receiving download progress: %w", err) - } - - case *Ping: - l.Debugln("read Ping message") - if state != stateReady { - return fmt.Errorf("protocol error: ping message in state %d", state) - } - // Nothing - - case *Close: - l.Debugln("read Close message") - return fmt.Errorf("closed by remote: %v", msg.Reason) - - default: - l.Debugf("read unknown message: %+T", msg) - return fmt.Errorf("protocol error: %s: unknown or empty message", c.id) + err = c.receiver.DownloadProgress(c.id, msg.Folder, msg.Updates) + } + if err != nil { + return newHandleError(err, msgContext) } } } @@ -1078,3 +1057,34 @@ func (c *rawConnection) lz4Decompress(src []byte) ([]byte, error) { } return decoded, nil } + +func newProtocolError(err error, msgContext string) error { + return fmt.Errorf("protocol error on %v: %w", msgContext, err) +} + +func newHandleError(err error, msgContext string) error { + return fmt.Errorf("handling %v: %w", msgContext, err) +} + +func messageContext(msg message) (string, error) { + switch msg := msg.(type) { + case *ClusterConfig: + return "cluster-config", nil + case *Index: + return fmt.Sprintf("index for %v", msg.Folder), nil + case *IndexUpdate: + return fmt.Sprintf("index-update for %v", msg.Folder), nil + case *Request: + return fmt.Sprintf(`request for "%v" in %v`, msg.Name, msg.Folder), nil + case *Response: + return "response", nil + case *DownloadProgress: + return fmt.Sprintf("download-progress for %v", msg.Folder), nil + case *Ping: + return "ping", nil + case *Close: + return "close", nil + default: + return "", errors.New("unknown or empty message") + } +}