mirror of
https://github.com/pocketbase/pocketbase.git
synced 2026-05-30 11:45:49 -04:00
192 lines
4.5 KiB
Go
192 lines
4.5 KiB
Go
package apis
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"log/slog"
|
|
"net/http"
|
|
"strings"
|
|
"time"
|
|
|
|
validation "github.com/go-ozzo/ozzo-validation/v4"
|
|
"github.com/pocketbase/pocketbase/core"
|
|
"github.com/pocketbase/pocketbase/tools/router"
|
|
)
|
|
|
|
// Limits applied to every query executed through the superuser SQL runner.
const (
	// runSQLMaxRows caps how many result rows are returned for a single query.
	runSQLMaxRows = 1000

	// runSQLMaxTimeout bounds how long a single query is allowed to run.
	runSQLMaxTimeout = time.Minute * 3
)
|
|
|
|
// bindSQLApi registers the SQL api endpoints.
|
|
func bindSQLApi(app core.App, rg *router.RouterGroup[*core.RequestEvent]) {
|
|
subGroup := rg.Group("/sql").Bind(RequireSuperuserAuth())
|
|
subGroup.POST("", runSQL)
|
|
}
|
|
|
|
func runSQL(e *core.RequestEvent) error {
|
|
// extra precaution in case manually invoked from somewhere else
|
|
if !e.HasSuperuserAuth() {
|
|
return e.ForbiddenError("", nil)
|
|
}
|
|
|
|
form := runSQLForm{}
|
|
|
|
err := e.BindBody(&form)
|
|
if err != nil {
|
|
return firstApiError(err, e.BadRequestError("An error occurred while loading the submitted data.", err))
|
|
}
|
|
|
|
err = form.validate()
|
|
if err != nil {
|
|
return firstApiError(err, e.BadRequestError("An error occurred while validating the submitted data.", err))
|
|
}
|
|
|
|
result, err := executeQuery(e.App, form.Query, runSQLMaxRows)
|
|
if err != nil {
|
|
return firstApiError(err, e.BadRequestError("Failed to execute query. Raw error:\n"+err.Error(), nil))
|
|
}
|
|
|
|
return e.JSON(http.StatusOK, result)
|
|
}
|
|
|
|
// runSQLForm holds the request payload accepted by the SQL runner endpoint.
type runSQLForm struct {
	// Query is the raw SQL to execute (bound from the "query" form/JSON field).
	Query string `form:"query" json:"query"`
}
|
|
|
|
func (form *runSQLForm) validate() error {
|
|
return validation.ValidateStruct(form,
|
|
validation.Field(&form.Query, validation.Required, validation.Length(0, 5000)),
|
|
)
|
|
}
|
|
|
|
// runSQLResultColumn describes a single column of a row-returning query result.
//
// NOTE: the field order determines the JSON key order, so keep it stable.
type runSQLResultColumn struct {
	// Name is the column name reported by the driver.
	Name string `json:"name"`

	// Type is the driver-reported database type name (may be empty).
	Type string `json:"type"`

	// Nullable is the driver-reported nullability; false when unknown.
	Nullable bool `json:"nullable"`
}
|
|
|
|
type runSQLResult struct {
|
|
ExecTime int64 `json:"execTime"`
|
|
AffectedRows int64 `json:"affectedRows"`
|
|
Columns []runSQLResultColumn `json:"columns"`
|
|
Rows [][]any `json:"rows"`
|
|
}
|
|
|
|
// knownWriteQueryPrefixes lists the upper-cased leading keywords of statements
// that the SQL runner treats as possible writes/mutations. Such statements are
// executed inside a transaction on the nonconcurrent db instead of being run as
// a row-returning query.
//
// The check is intentionally loose; e.g. a "WITH ... INSERT" statement cannot
// be recognized by its prefix alone.
var knownWriteQueryPrefixes = []string{"INSERT", "CREATE", "UPDATE", "DELETE", "DROP", "DETACH", "ALTER", "REPLACE"}
|
|
|
|
func executeQuery(app core.App, query string, maxRows int) (*runSQLResult, error) {
|
|
query = strings.TrimSpace(query)
|
|
if query == "" {
|
|
// see https://github.com/mattn/go-sqlite3/issues/950
|
|
return nil, errors.New("empty query")
|
|
}
|
|
|
|
var isPossibleWriteQuery bool
|
|
|
|
// loosely check the query type
|
|
ucQuery := strings.ToUpper(query)
|
|
if !strings.HasPrefix(ucQuery, "SELECT") {
|
|
for _, prefix := range knownWriteQueryPrefixes {
|
|
if strings.HasPrefix(ucQuery, prefix) {
|
|
isPossibleWriteQuery = true
|
|
break
|
|
}
|
|
}
|
|
}
|
|
|
|
// note: don't extend the request context to minimize the risk of
|
|
// causing integrity issues with custom non-transaction mutations
|
|
ctx, cancelFunc := context.WithTimeout(context.Background(), runSQLMaxTimeout)
|
|
defer cancelFunc()
|
|
|
|
result := &runSQLResult{
|
|
// init empty slices to ensure "[]" serialization
|
|
Columns: []runSQLResultColumn{},
|
|
Rows: [][]any{},
|
|
}
|
|
|
|
now := time.Now()
|
|
defer func() {
|
|
result.ExecTime = time.Since(now).Milliseconds()
|
|
}()
|
|
|
|
// assume write/mutation query
|
|
// ---------------------------------------------------------------
|
|
if isPossibleWriteQuery {
|
|
// auto wrap in transaction in case there are multiple inline queries
|
|
txErr := app.RunInTransaction(func(txApp core.App) error {
|
|
execResult, err := txApp.NonconcurrentDB().NewQuery(query).WithContext(ctx).Execute()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
result.AffectedRows, err = execResult.RowsAffected()
|
|
if err != nil {
|
|
// non-critical error (e.g. not supported by the driver)
|
|
txApp.Logger().Debug("Unable to fetch affected rows", slog.String("error", err.Error()))
|
|
}
|
|
|
|
return nil
|
|
})
|
|
if txErr != nil {
|
|
return nil, txErr
|
|
}
|
|
|
|
return result, nil
|
|
}
|
|
|
|
// assume query returning rows
|
|
// ---------------------------------------------------------------
|
|
rows, err := app.ConcurrentDB().NewQuery(query).WithContext(ctx).Rows()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
// populate columns info
|
|
// ---
|
|
colTypes, err := rows.ColumnTypes()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
for _, colType := range colTypes {
|
|
col := runSQLResultColumn{
|
|
Name: colType.Name(),
|
|
Type: colType.DatabaseTypeName(),
|
|
}
|
|
col.Nullable, _ = colType.Nullable()
|
|
|
|
result.Columns = append(result.Columns, col)
|
|
}
|
|
|
|
// populate rows
|
|
// ---
|
|
for rows.Next() {
|
|
if len(result.Rows) >= maxRows {
|
|
break
|
|
}
|
|
|
|
rowData := make([]any, len(colTypes))
|
|
for i := 0; i < len(colTypes); i++ {
|
|
var v *string
|
|
rowData[i] = &v
|
|
}
|
|
|
|
err := rows.Scan(rowData...)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
result.Rows = append(result.Rows, rowData)
|
|
}
|
|
|
|
err = rows.Err()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return result, nil
|
|
}
|