diff --git a/lib/protocol/protocol.go b/lib/protocol/protocol.go index 210d889f9..fa4746cbe 100644 --- a/lib/protocol/protocol.go +++ b/lib/protocol/protocol.go @@ -49,6 +49,10 @@ const ( // MaxBlockSize is the maximum block size allowed MaxBlockSize = 16 << MiB + // MaxRequestSize is the largest amount of data that can be read in a + // single request + MaxRequestSize = 2 * MaxBlockSize + // DesiredPerFileBlocks is the number of blocks we aim for per file DesiredPerFileBlocks = 2000 @@ -457,14 +461,6 @@ func (c *rawConnection) dispatcherLoop() (err error) { } } - switch msg := msg.(type) { - case *bep.Request: - err = checkFilename(msg.Name) - } - if err != nil { - return newProtocolError(err, msgContext) - } - switch msg := msg.(type) { case *bep.ClusterConfig: err = c.model.ClusterConfig(clusterConfigFromWire(msg)) @@ -484,6 +480,15 @@ func (c *rawConnection) dispatcherLoop() (err error) { err = c.handleIndexUpdate(idxUp) case *bep.Request: + if err := checkFilename(msg.Name); err != nil { + return newProtocolError(err, msgContext) + } + if msg.Size <= 0 { + return newProtocolError(fmt.Errorf("request size %d too small", msg.Size), msgContext) + } + if msg.Size > MaxRequestSize { + return newProtocolError(fmt.Errorf("request size %d exceeds maximum allowed", msg.Size), msgContext) + } go c.handleRequest(requestFromWire(msg)) case *bep.Response: diff --git a/lib/protocol/protocol_test.go b/lib/protocol/protocol_test.go index 4873ec987..7f3d23f1b 100644 --- a/lib/protocol/protocol_test.go +++ b/lib/protocol/protocol_test.go @@ -11,7 +11,9 @@ import ( "context" "encoding/hex" "errors" + "fmt" "io" + "strings" "sync" "testing" "time" @@ -541,6 +543,84 @@ func TestDispatcherToCloseDeadlock(t *testing.T) { } } +func TestRequestMaxSize(t *testing.T) { + invalidSize := []int{-65536, 0, MaxRequestSize + 1} + for _, s := range invalidSize { + t.Run(fmt.Sprintf("invalid/%d", s), func(t *testing.T) { + m := newTestModel() + rw := testutil.NewBlockingRW() + c := getRawConnection(NewConnection(c0ID, rw, &testutil.NoopRW{}, testutil.NoopCloser{}, m, new(mockedConnectionInfo), CompressionAlways, testKeyGen)) + c.Start() + defer closeAndWait(c, rw) + + c.inbox <- &bep.ClusterConfig{} + + // A request at exactly MaxRequestSize should be accepted. + c.inbox <- &bep.Request{ + Id: 1, + Name: "valid", + Size: MaxRequestSize, + } + + res := <-c.outbox + if msg, ok := res.msg.(*bep.Response); !ok || msg.Id != 1 { + t.Errorf("bad response %#v", msg) + } + + // A request with an invalid size should cause the dispatcher to + // return with a protocol error. + c.inbox <- &bep.Request{ + Id: 2, + Name: "invalid", + Size: int32(s), + } + + select { + case <-c.dispatcherLoopStopped: + case <-time.After(time.Second): + t.Fatal("timed out before dispatcher loop terminated") + } + + err := m.closedError() + if err == nil { + t.Fatal("expected connection to be closed with an error") + } + if !strings.Contains(err.Error(), "protocol error") { + t.Errorf("expected a protocol error, got %v", err) + } + }) + } +} + +func TestRequestInvalidFilename(t *testing.T) { + m := newTestModel() + rw := testutil.NewBlockingRW() + c := getRawConnection(NewConnection(c0ID, rw, &testutil.NoopRW{}, testutil.NoopCloser{}, m, new(mockedConnectionInfo), CompressionAlways, testKeyGen)) + c.Start() + defer closeAndWait(c, rw) + + c.inbox <- &bep.ClusterConfig{} + c.inbox <- &bep.Request{ + Id: 1, + Name: "../escape", + Size: 1024, + } + + select { + case <-c.dispatcherLoopStopped: + case <-time.After(time.Second): + t.Fatal("timed out before dispatcher loop terminated") + } + + err := m.closedError() + if err == nil { + t.Fatal("expected connection to be closed with an error") + } + if !strings.Contains(err.Error(), "protocol error") { + t.Errorf("expected a protocol error, got %v", err) + } +} + func TestIndexIDString(t *testing.T) { // Index ID is a 64 bit, zero padded hex integer. var i IndexID = 42