[#7575] use memory+file buffer when rereading the request body (fix #7572)

This commit is contained in:
Gani Georgiev
2026-03-09 17:19:09 +02:00
parent 93e3eb3a35
commit ba8b51af58
9 changed files with 454 additions and 35 deletions

View File

@@ -1,3 +1,11 @@
## v0.36.7 (WIP)
- Fixes high memory usage with large file uploads ([#7572](https://github.com/pocketbase/pocketbase/discussions/7572)).
- (@todo) Updated `modernc.org/sqlite` to v1.47.0 (SQLite v3.52.0).
_It fixes a [database corruption bug](https://sqlite.org/wal.html#walresetbug) that is very difficult to trigger, but it is still advised to upgrade._
## v0.36.6
- Set `NumberField.OnlyInt:true` for the generated View collection schema fields when a view column expression is known to return int-only values ([#7538](https://github.com/pocketbase/pocketbase/issues/7538)).

View File

@@ -364,6 +364,7 @@ func processInternalRequest(
// assign request
event.Request = r
event.Request.Body = &router.RereadableReadCloser{ReadCloser: r.Body} // enables multiple reads
defer event.Request.Body.Close()
// assign response
rec := httptest.NewRecorder()

View File

@@ -112,9 +112,21 @@ func (r *limitedReader) Read(b []byte) (int, error) {
return n, nil
}
// explicit type assertions to ensure that the main struct methods will be invoked
// (extra precaution in case of nested interface wrapper erasure)
// ---
func (r *limitedReader) Reread() {
rr, ok := r.ReadCloser.(router.Rereader)
rereader, ok := r.ReadCloser.(router.Rereader)
if ok {
rr.Reread()
rereader.Reread()
}
}
// Close implements the standard [io.Closer] interface.
//
// It forwards the call to the wrapped reader when it also implements
// [io.Closer], otherwise it is a no-op and returns nil.
func (r *limitedReader) Close() error {
closer, ok := r.ReadCloser.(io.Closer)
if ok {
return closer.Close()
}
return nil
}

View File

@@ -0,0 +1,125 @@
package router
import (
"bytes"
"errors"
"io"
"os"
)
var _ io.ReadWriteCloser = (*bufferWithFile)(nil)
// newBufferWithFile initializes and returns a new bufferWithFile with the specified memoryLimit.
//
// If memoryLimit is negative or zero, defaults to [DefaultMaxMemory].
func newBufferWithFile(memoryLimit int64) *bufferWithFile {
	limit := memoryLimit
	if limit < 1 {
		limit = DefaultMaxMemory
	}

	return &bufferWithFile{
		memoryLimit: limit,
		buf:         new(bytes.Buffer),
	}
}
// bufferWithFile is similar to [bytes.Buffer] but after the limit it
// fallbacks to a temporary file to minimize excessive memory usage.
//
// Use [newBufferWithFile] to create a usable instance.
type bufferWithFile struct {
// buf holds the written data while its total size stays within memoryLimit
buf *bytes.Buffer
// file is the temp disk fallback; nil until the memory limit is exceeded
file *os.File
// memoryLimit is the max number of bytes kept in buf before switching to file
memoryLimit int64
// fileReadOffset tracks the next read position in file (reads use ReadAt)
fileReadOffset int64
}
// Read implements the standard [io.Reader] interface by reading
// up to len(p) bytes into p.
//
// Data is drained in write order: first from the in-memory buffer,
// then from the temp fallback file (if one was created by Write).
//
// When both sources are exhausted and nothing was read it returns
// (0, io.EOF), for consistency with the [bytes.Buffer] rules.
func (b *bufferWithFile) Read(p []byte) (n int, err error) {
if b.buf == nil {
return 0, errors.New("[bufferWithFile.Read] not initialized or already closed")
}
// capture the requested length once for the slicing and EOF checks below
maxToRead := len(p)
// read first from the memory buffer
if b.buf.Len() > 0 {
n, err = b.buf.Read(p)
if err != nil && err != io.EOF {
return n, err
}
}
// continue reading from the file to fill the remaining bytes
// (ReadAt with a manually tracked offset leaves the file write position untouched)
if n < maxToRead && b.file != nil {
fileN, fileErr := b.file.ReadAt(p[n:maxToRead], b.fileReadOffset)
b.fileReadOffset += int64(fileN)
n += fileN
err = fileErr
}
// return EOF if the buffers are empty and nothing has been read
// (to minimize potential breaking changes and for consistency with the bytes.Buffer rules)
if n == 0 && maxToRead > 0 && err == nil {
return 0, io.EOF
}
return n, err
}
// Write implements the standard [io.Writer] interface by writing the
// content of p into the buffer.
//
// If the in-memory buffer doesn't have enough space to hold len(p)
// more bytes, p is written to a temporary disk file instead.
func (b *bufferWithFile) Write(p []byte) (int, error) {
	if b.buf == nil {
		return 0, errors.New("[bufferWithFile.Write] not initialized or already closed")
	}

	// already switched to the file fallback -> keep writing to it
	if b.file != nil {
		return b.file.Write(p)
	}

	// the write would exceed the memory limit -> create the temp file fallback
	// (data accumulated so far stays in the memory buffer; Read drains it first)
	//
	// note: the previous inner `if b.file == nil` check was redundant because
	// the `b.file != nil` branch above already returned
	if int64(b.buf.Len()+len(p)) > b.memoryLimit {
		var err error
		b.file, err = os.CreateTemp("", "pb_buffer_file_*")
		if err != nil {
			return 0, err
		}
		return b.file.Write(p)
	}

	// still under the limit -> write in memory
	return b.buf.Write(p)
}
// Close implements the standard [io.Closer] interface.
//
// It releases the in-memory buffer and removes the temporary
// fallback file (if one was created).
//
// It is safe to call Close multiple times.
// Once Close is invoked the buffer no longer can be used and should be discarded.
func (b *bufferWithFile) Close() error {
	if b.file != nil {
		closeErr := b.file.Close()
		removeErr := os.Remove(b.file.Name())
		if err := errors.Join(closeErr, removeErr); err != nil {
			return err
		}
		b.file = nil
	}

	b.buf = nil

	return nil
}

View File

@@ -0,0 +1,193 @@
package router
import (
"errors"
"io"
"io/fs"
"os"
"testing"
)
// TestNewBufferWithFile verifies the constructor's memoryLimit
// normalization (non-positive values fall back to DefaultMaxMemory).
func TestNewBufferWithFile(t *testing.T) {
	t.Parallel()

	scenarios := []struct {
		name     string
		limit    int64
		expected int64
	}{
		{"negative limit", -1, DefaultMaxMemory},
		{"zero limit", 0, DefaultMaxMemory},
		{"> 0", 1, 1},
	}

	for _, s := range scenarios {
		t.Run(s.name, func(t *testing.T) {
			b := newBufferWithFile(s.limit)

			if b.file != nil {
				t.Fatalf("Expected no file descriptor to be open, got %v", b.file)
			}

			if b.buf == nil {
				t.Fatal("Expected buf to be initialized, got nil")
			}

			// fix: the failure message previously reported a hard-coded 10
			// instead of the scenario's expected limit
			if b.memoryLimit != s.expected {
				t.Fatalf("Expected %d limit, got %d", s.expected, b.memoryLimit)
			}
		})
	}
}
// TestBufferWithFile_WriteReadClose exercises the full write -> read -> close
// lifecycle of bufferWithFile with a tiny 4-byte memory limit so that the
// temp file fallback is triggered.
//
// note: the subtests share state and must run in the declared order.
func TestBufferWithFile_WriteReadClose(t *testing.T) {
	t.Parallel()

	b := newBufferWithFile(4)

	t.Run("write under limit", func(t *testing.T) {
		n, err := b.Write([]byte("ab"))
		if err != nil {
			t.Fatal(err)
		}
		if n != 2 {
			t.Fatalf("Expected %d bytes to be written, got %v", 2, n)
		}
		if l := b.buf.Len(); l != 2 {
			t.Fatalf("Expected memory buf length %d, got %d", 2, l)
		}
		if b.file != nil {
			t.Fatalf("Expected temp file to remain nil, got %v", b.file)
		}
	})

	t.Run("write under limit (again)", func(t *testing.T) {
		n, err := b.Write([]byte("c"))
		if err != nil {
			t.Fatal(err)
		}
		if n != 1 {
			t.Fatalf("Expected %d bytes to be written, got %v", 1, n)
		}
		if l := b.buf.Len(); l != 3 {
			t.Fatalf("Expected memory buf length %d, got %d", 3, l)
		}
		if b.file != nil {
			t.Fatalf("Expected temp file to remain nil, got %v", b.file)
		}
	})

	t.Run("write beyond limit (aka. skip memory buf and write into file)", func(t *testing.T) {
		n, err := b.Write([]byte("de"))
		if err != nil {
			t.Fatal(err)
		}
		if n != 2 {
			t.Fatalf("Expected %d bytes to be written, got %v", 2, n)
		}
		if l := b.buf.Len(); l != 3 {
			t.Fatalf("Expected memory buf length to be unchanged (%d), got %d", 3, l)
		}
		if b.file == nil {
			t.Fatal("Expected temp file to be initialized")
		}
	})

	t.Run("read 0 bytes from non-empty buffer", func(t *testing.T) {
		temp := []byte{}
		n, err := b.Read(temp)
		if err != nil { // should return nil for consistency with bytes.Buffer
			t.Fatalf("Expected nil, got %v", err)
		}
		if n != 0 {
			t.Fatalf("Expected 0 bytes to be read, got %d (%q)", n, temp)
		}
	})

	t.Run("read under limit", func(t *testing.T) {
		expected := "ab"
		temp := make([]byte, 2)
		n, err := b.Read(temp)
		if err != nil && err != io.EOF {
			t.Fatal(err)
		}
		if n != len(temp) {
			t.Fatalf("Expected %d bytes to be read, got %d (%q)", len(temp), n, temp)
		}
		if str := string(temp); str != expected {
			t.Fatalf("Expected to read %q, got %q", expected, str)
		}
	})

	t.Run("read beyond limit", func(t *testing.T) {
		expected := "cde"
		temp := make([]byte, 3)
		n, err := b.Read(temp)
		if err != nil && err != io.EOF {
			t.Fatal(err)
		}
		if n != len(temp) {
			t.Fatalf("Expected %d bytes to be read, got %d (%q)", len(temp), n, temp)
		}
		if str := string(temp); str != expected {
			t.Fatalf("Expected to read %q, got %q", expected, str)
		}
	})

	t.Run("read from empty buffers", func(t *testing.T) {
		temp := make([]byte, 3)
		n, err := b.Read(temp)
		if err != io.EOF {
			t.Fatalf("Expected EOF, got %v", err)
		}
		if n != 0 {
			t.Fatalf("Expected 0 bytes to be read, got %d (%q)", n, temp)
		}
	})

	t.Run("close cleanup", func(t *testing.T) {
		if b.file == nil {
			t.Fatal("Expected temp file to be initialized, got nil")
		}

		filename := b.file.Name()

		// fix: the previous `err != nil || errors.Is(err, fs.ErrNotExist)`
		// condition was redundant - a non-nil err already covers both cases
		if _, err := os.Stat(filename); err != nil {
			t.Fatalf("Expected the temp file to exist and be accessible, got %v", err)
		}

		if err := b.Close(); err != nil {
			t.Fatal(err)
		}

		// the temp file must be gone after Close
		if _, err := os.Stat(filename); !errors.Is(err, fs.ErrNotExist) {
			t.Fatalf("Expected the temp file to be deleted after close, got %v", err)
		}

		if b.buf != nil || b.file != nil {
			t.Fatal("Expected the internal buffers to be nil after close")
		}
	})
}

View File

@@ -322,7 +322,7 @@ func (e *Event) InternalServerError(message string, errData any) *ApiError {
// Binders
// -------------------------------------------------------------------
const DefaultMaxMemory = 32 << 20 // 32mb
const DefaultMaxMemory = 16 << 20 // 16mb
// BindBody unmarshal the request body into the provided dst.
//

View File

@@ -1,13 +1,13 @@
package router
import (
"bytes"
"errors"
"io"
)
var (
_ io.ReadCloser = (*RereadableReadCloser)(nil)
_ Rereader = (*RereadableReadCloser)(nil)
_ io.ReadCloser = (*RereadableReadCloser)(nil)
)
// Rereader defines an interface for rewindable readers.
@@ -15,30 +15,44 @@ type Rereader interface {
Reread()
}
// RereadableReadCloser defines a wrapper around a io.ReadCloser reader
// RereadableReadCloser defines a wrapper around a [io.ReadCloser] reader
// allowing to read the original reader multiple times.
//
// NB! Make sure to call Close after done working with the reader.
type RereadableReadCloser struct {
io.ReadCloser
copy *bytes.Buffer
active io.Reader
copy io.ReadWriteCloser
closeErrors []error
// MaxMemory specifies the max size of the in memory copy buffer
// before switching to read/write from temp disk file.
//
// If negative or zero, defaults to [DefaultMaxMemory].
MaxMemory int64
}
// Read implements the standard io.Reader interface.
// Read implements the standard [io.Reader] interface.
//
// It reads up to len(b) bytes into b and at the same time writes
// the read data into an internal bytes buffer.
// It reads up to len(p) bytes into p and at the same time copies
// the read data into an internal buffer (memory + temp file).
//
// On EOF the r is "rewinded" to allow reading from r multiple times.
func (r *RereadableReadCloser) Read(b []byte) (int, error) {
if r.active == nil {
// On EOF r is "rewinded" to allow reading multiple times.
func (r *RereadableReadCloser) Read(p []byte) (int, error) {
n, err := r.ReadCloser.Read(p)
// copy the read bytes into the internal buffer
if n > 0 {
if r.copy == nil {
r.copy = &bytes.Buffer{}
r.copy = newBufferWithFile(r.MaxMemory)
}
if n, err := r.copy.Write(p[:n]); err != nil {
return n, err
}
r.active = io.TeeReader(r.ReadCloser, r.copy)
}
n, err := r.active.Read(b)
// end reached -> reset for the next read
if err == io.EOF {
r.Reread()
}
@@ -50,11 +64,33 @@ func (r *RereadableReadCloser) Read(b []byte) (int, error) {
//
// note: not named Reset to avoid conflicts with other reader interfaces.
func (r *RereadableReadCloser) Reread() {
if r.copy == nil || r.copy.Len() == 0 {
return // nothing to reset or it has been already reset
if r.copy == nil {
return // nothing to reset
}
oldCopy := r.copy
r.copy = &bytes.Buffer{}
r.active = io.TeeReader(oldCopy, r.copy)
// eagerly close the old reader to prevent accumulating too much memory or temp files
if err := r.ReadCloser.Close(); err != nil {
r.closeErrors = append(r.closeErrors, err)
}
r.ReadCloser = r.copy
r.copy = nil
}
// Close implements the standard [io.Closer] interface by cleaning up related resources.
//
// It is safe to call Close multiple times.
// Once Close is invoked the reader no longer can be used and should be discarded.
func (r *RereadableReadCloser) Close() error {
if r.copy != nil {
if err := r.copy.Close(); err != nil {
r.closeErrors = append(r.closeErrors, err)
}
}
if err := r.ReadCloser.Close(); err != nil {
r.closeErrors = append(r.closeErrors, err)
}
return errors.Join(r.closeErrors...)
}

View File

@@ -1,28 +1,69 @@
package router_test
package router
import (
"errors"
"io"
"io/fs"
"os"
"strconv"
"strings"
"testing"
"github.com/pocketbase/pocketbase/tools/router"
)
func TestRereadableReadCloser(t *testing.T) {
content := "test"
rereadable := &router.RereadableReadCloser{
rereadable := &RereadableReadCloser{
ReadCloser: io.NopCloser(strings.NewReader(content)),
MaxMemory: 2, // should store the rest 2 bytes in temp file
}
// read multiple times
for i := 0; i < 3; i++ {
result, err := io.ReadAll(rereadable)
if err != nil {
t.Fatalf("[read:%d] %v", i, err)
}
if str := string(result); str != content {
t.Fatalf("[read:%d] Expected %q, got %q", i, content, result)
totalRereads := 5
tempFilenames := make([]string, 0, totalRereads)
// reread multiple times
for i := 0; i < totalRereads; i++ {
t.Run("run_"+strconv.Itoa(i), func(t *testing.T) {
if i > 3 {
// test also with manual Reread calls to ensure that
// r.copy is reset and written to only when there are n>0 bytes
rereadable.Reread()
}
result, err := io.ReadAll(rereadable)
if err != nil {
t.Fatalf("[read:%d] %v", i, err)
}
if str := string(result); str != content {
t.Fatalf("[read:%d] Expected %q, got %q", i, content, result)
}
b, ok := rereadable.ReadCloser.(*bufferWithFile)
if !ok {
t.Fatalf("Expected bufferWithFile replacement, got %v", b)
}
if b.file != nil {
tempFilenames = append(tempFilenames, b.file.Name())
}
})
}
if v := len(tempFilenames); v != totalRereads {
t.Fatalf("Expected %d temp files to have been created during the previous rereads, got %d", totalRereads, v)
}
err := rereadable.Close()
if err != nil {
t.Fatalf("Expected no close errors, got %v", err)
}
// ensure that no lingering temp files are left after close
for _, name := range tempFilenames {
info, err := os.Stat(name)
if err == nil || !errors.Is(err, fs.ErrNotExist) {
t.Fatalf("Expected file name %q to be deleted, got %v (%v)", name, info, err)
}
}
}

View File

@@ -132,7 +132,10 @@ func (r *Router[T]) loadMux(mux *http.ServeMux, group *RouterGroup[T], parents [
resp = &ResponseWriter{ResponseWriter: resp}
// wrap the request body to allow multiple reads
req.Body = &RereadableReadCloser{ReadCloser: req.Body}
body := &RereadableReadCloser{ReadCloser: req.Body}
defer body.Close()
req.Body = body
event, cleanupFunc := r.eventFactory(resp, req)