// File: pocketbase/apis/sql.go
// (source listing metadata: 192 lines, 4.5 KiB, Go)

package apis
import (
"context"
"errors"
"log/slog"
"net/http"
"strings"
"time"
validation "github.com/go-ozzo/ozzo-validation/v4"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tools/router"
)
// Limits applied to ad-hoc queries submitted to the SQL api.
const (
	// runSQLMaxRows is the maximum number of result rows collected and
	// returned for a single rows-returning query.
	runSQLMaxRows = 1000

	// runSQLMaxTimeout bounds how long a single query is allowed to run.
	runSQLMaxTimeout = 3 * time.Minute
)
// bindSQLApi registers the SQL api endpoints under the "/sql" group of rg.
//
// Every route in the group is guarded by RequireSuperuserAuth; currently the
// group exposes a single POST handler (runSQL).
func bindSQLApi(app core.App, rg *router.RouterGroup[*core.RequestEvent]) {
	rg.Group("/sql").
		Bind(RequireSuperuserAuth()).
		POST("", runSQL)
}
// runSQL handles the superuser SQL endpoint.
//
// It binds the request body into a runSQLForm, validates it, executes the
// submitted query through executeQuery (capped at runSQLMaxRows rows) and
// responds with the result serialized as JSON with status 200.
//
// A 403 is returned when the request lacks superuser auth. A 400 is returned
// when binding, validation or query execution fails.
func runSQL(e *core.RequestEvent) error {
	// defensive re-check: the route group already requires superuser auth,
	// but this handler could also be invoked manually from elsewhere
	if !e.HasSuperuserAuth() {
		return e.ForbiddenError("", nil)
	}

	var form runSQLForm

	if bindErr := e.BindBody(&form); bindErr != nil {
		return firstApiError(bindErr, e.BadRequestError("An error occurred while loading the submitted data.", bindErr))
	}

	if validationErr := form.validate(); validationErr != nil {
		return firstApiError(validationErr, e.BadRequestError("An error occurred while validating the submitted data.", validationErr))
	}

	result, execErr := executeQuery(e.App, form.Query, runSQLMaxRows)
	if execErr != nil {
		// the raw driver error is surfaced to the (superuser) client on purpose
		return firstApiError(execErr, e.BadRequestError("Failed to execute query. Raw error:\n"+execErr.Error(), nil))
	}

	return e.JSON(http.StatusOK, result)
}
// runSQLForm is the request body accepted by the SQL api.
type runSQLForm struct {
	// Query is the raw SQL text to execute.
	Query string `form:"query" json:"query"`
}

// validate checks that the query is present and at most 5000 characters long.
func (form *runSQLForm) validate() error {
	queryRules := validation.Field(
		&form.Query,
		validation.Required,
		validation.Length(0, 5000),
	)

	return validation.ValidateStruct(form, queryRules)
}
// runSQLResultColumn describes a single column of a rows-returning query,
// as reported by the driver's column type information.
type runSQLResultColumn struct {
	// Name is the column name.
	Name string `json:"name"`
	// Type is the driver-reported database type name.
	Type string `json:"type"`
	// Nullable is the driver-reported nullability (false when unknown).
	Nullable bool `json:"nullable"`
}

// runSQLResult is the JSON payload returned by the SQL api.
//
// Field order is significant: it determines the key order of the serialized JSON.
type runSQLResult struct {
	// ExecTime is the query execution time in milliseconds.
	ExecTime int64 `json:"execTime"`
	// AffectedRows is populated only for queries executed as possible writes.
	AffectedRows int64 `json:"affectedRows"`
	// Columns describes the columns of a rows-returning query.
	Columns []runSQLResultColumn `json:"columns"`
	// Rows holds the scanned row values, one slice per row.
	Rows [][]any `json:"rows"`
}
// knownWriteQueryPrefixes lists the upper-cased leading keywords that make
// executeQuery treat a query as a possible mutation.
//
// Such queries are run with Execute() inside a transaction on the
// nonconcurrent DB instead of going through the rows-returning path.
//
// ALTER and REPLACE are included because in SQLite they only start mutating
// statements (ALTER TABLE, REPLACE INTO). Without them, those statements would
// bypass the transaction wrapper.
var knownWriteQueryPrefixes = []string{"INSERT", "CREATE", "UPDATE", "DELETE", "DROP", "DETACH", "ALTER", "REPLACE"}
// executeQuery runs a single ad-hoc SQL query and collects its outcome.
//
// The query is trimmed and classified by its leading keyword:
//   - if it starts with one of knownWriteQueryPrefixes (and not with SELECT), it
//     is treated as a possible mutation: it runs with Execute() inside a
//     transaction on the nonconcurrent DB, and only AffectedRows is reported;
//   - otherwise it runs on the concurrent DB and at most maxRows result rows are
//     collected (maxRows <= 0 collects none), each value scanned into a *string.
//
// The query runs under its own context limited by runSQLMaxTimeout rather
// than the request context. ExecTime on the returned result holds the elapsed
// time in milliseconds.
//
// An error is returned for a blank query, when the query fails, or when
// reading the result columns or rows fails.
func executeQuery(app core.App, query string, maxRows int) (*runSQLResult, error) {
	query = strings.TrimSpace(query)
	if query == "" {
		// reject blank input upfront
		// see https://github.com/mattn/go-sqlite3/issues/950
		return nil, errors.New("empty query")
	}

	// loosely check the query type by its leading keyword
	var isPossibleWriteQuery bool
	ucQuery := strings.ToUpper(query)
	if !strings.HasPrefix(ucQuery, "SELECT") {
		for _, prefix := range knownWriteQueryPrefixes {
			if strings.HasPrefix(ucQuery, prefix) {
				isPossibleWriteQuery = true
				break
			}
		}
	}

	// note: don't extend the request context to minimize the risk of
	// causing integrity issues with custom non-transaction mutations
	ctx, cancelFunc := context.WithTimeout(context.Background(), runSQLMaxTimeout)
	defer cancelFunc()

	result := &runSQLResult{
		// init empty slices to ensure "[]" serialization
		Columns: []runSQLResultColumn{},
		Rows:    [][]any{},
	}

	// result is a pointer, so this deferred assignment is still visible to the
	// caller after the return values have been evaluated
	start := time.Now()
	defer func() {
		result.ExecTime = time.Since(start).Milliseconds()
	}()

	// assume write/mutation query
	// ---------------------------------------------------------------
	if isPossibleWriteQuery {
		// auto wrap in transaction in case there are multiple inline queries
		txErr := app.RunInTransaction(func(txApp core.App) error {
			execResult, err := txApp.NonconcurrentDB().NewQuery(query).WithContext(ctx).Execute()
			if err != nil {
				return err
			}

			result.AffectedRows, err = execResult.RowsAffected()
			if err != nil {
				// non-critical error (e.g. not supported by the driver)
				txApp.Logger().Debug("Unable to fetch affected rows", slog.String("error", err.Error()))
			}

			return nil
		})
		if txErr != nil {
			return nil, txErr
		}

		return result, nil
	}

	// assume query returning rows
	// ---------------------------------------------------------------
	rows, err := app.ConcurrentDB().NewQuery(query).WithContext(ctx).Rows()
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	// populate columns info
	colTypes, err := rows.ColumnTypes()
	if err != nil {
		return nil, err
	}

	result.Columns = make([]runSQLResultColumn, 0, len(colTypes))
	for _, colType := range colTypes {
		col := runSQLResultColumn{
			Name: colType.Name(),
			Type: colType.DatabaseTypeName(),
		}
		// the second return value (whether the driver reports nullability)
		// is deliberately ignored; Nullable stays false in that case
		col.Nullable, _ = colType.Nullable()
		result.Columns = append(result.Columns, col)
	}

	// populate rows
	//
	// The limit is checked BEFORE calling rows.Next() so that the driver is
	// never asked to step past the cap. With a selective WHERE clause, finding
	// the (maxRows+1)th row can mean scanning the rest of the table, only for
	// that row to be thrown away.
	for len(result.Rows) < maxRows && rows.Next() {
		rowData := make([]any, len(colTypes))
		for i := range rowData {
			var v *string
			rowData[i] = &v
		}

		if err := rows.Scan(rowData...); err != nil {
			return nil, err
		}

		result.Rows = append(result.Rows, rowData)
	}

	if err := rows.Err(); err != nil {
		return nil, err
	}

	return result, nil
}