diff --git a/core/http/endpoints/openai/image.go b/core/http/endpoints/openai/image.go index 3575fee2b..86c7bc2dc 100644 --- a/core/http/endpoints/openai/image.go +++ b/core/http/endpoints/openai/image.go @@ -23,10 +23,15 @@ import ( "github.com/mudler/LocalAI/core/backend" model "github.com/mudler/LocalAI/pkg/model" + "github.com/mudler/LocalAI/pkg/utils" "github.com/mudler/xlog" ) func downloadFile(url string) (string, error) { + if err := utils.ValidateExternalURL(url); err != nil { + return "", fmt.Errorf("URL validation failed: %w", err) + } + // Get the data resp, err := http.Get(url) if err != nil { diff --git a/pkg/utils/base64.go b/pkg/utils/base64.go index 2d22a27be..905495a18 100644 --- a/pkg/utils/base64.go +++ b/pkg/utils/base64.go @@ -21,6 +21,10 @@ var dataURIPattern = regexp.MustCompile(`^data:([^;]+);base64,`) // GetContentURIAsBase64 checks if the string is an URL, if it's an URL downloads the content in memory encodes it in base64 and returns the base64 string, otherwise returns the string by stripping base64 data headers func GetContentURIAsBase64(s string) (string, error) { if strings.HasPrefix(s, "http") || strings.HasPrefix(s, "https") { + if err := ValidateExternalURL(s); err != nil { + return "", fmt.Errorf("URL validation failed: %w", err) + } + // download the image resp, err := base64DownloadClient.Get(s) if err != nil { diff --git a/pkg/utils/urlfetch.go b/pkg/utils/urlfetch.go new file mode 100644 index 000000000..d32a1ba0a --- /dev/null +++ b/pkg/utils/urlfetch.go @@ -0,0 +1,78 @@ +package utils + +import ( + "fmt" + "net" + "net/url" + "strings" +) + +// ValidateExternalURL checks that the given URL does not point to a private, +// loopback, link-local, or otherwise internal network address. This prevents +// Server-Side Request Forgery (SSRF) attacks where a user-supplied URL could +// be used to probe internal services or cloud metadata endpoints. +func ValidateExternalURL(rawURL string) error { + parsed, err := url.Parse(rawURL) + if err != nil { + return fmt.Errorf("invalid URL: %w", err) + } + + scheme := strings.ToLower(parsed.Scheme) + if scheme != "http" && scheme != "https" { + return fmt.Errorf("unsupported URL scheme: %s", scheme) + } + + hostname := parsed.Hostname() + if hostname == "" { + return fmt.Errorf("URL has no hostname") + } + + // Block well-known internal hostnames + lower := strings.ToLower(hostname) + if lower == "localhost" || strings.HasSuffix(lower, ".local") { + return fmt.Errorf("requests to internal hosts are not allowed") + } + + // Block cloud metadata service hostnames + if lower == "metadata.google.internal" || lower == "instance-data" { + return fmt.Errorf("requests to cloud metadata services are not allowed") + } + + ips, err := net.LookupHost(hostname) + if err != nil { + return fmt.Errorf("failed to resolve hostname: %w", err) + } + + for _, ipStr := range ips { + ip := net.ParseIP(ipStr) + if ip == nil { + return fmt.Errorf("unable to parse resolved IP: %s", ipStr) + } + + if !isPublicIP(ip) { + return fmt.Errorf("requests to internal network addresses are not allowed") + } + } + + return nil +} + +func isPublicIP(ip net.IP) bool { + if ip.IsLoopback() || + ip.IsLinkLocalUnicast() || + ip.IsLinkLocalMulticast() || + ip.IsPrivate() || + ip.IsUnspecified() { + return false + } + + // Block IPv4-mapped IPv6 addresses that wrap private IPv4 + if ip4 := ip.To4(); ip4 != nil { + return !ip4.IsLoopback() && + !ip4.IsLinkLocalUnicast() && + !ip4.IsPrivate() && + !ip4.IsUnspecified() + } + + return true +} diff --git a/pkg/utils/urlfetch_test.go b/pkg/utils/urlfetch_test.go new file mode 100644 index 000000000..62c409e5f --- /dev/null +++ b/pkg/utils/urlfetch_test.go @@ -0,0 +1,99 @@ +package utils_test + +import ( + . "github.com/mudler/LocalAI/pkg/utils" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var _ = Describe("utils/urlfetch tests", func() { + Context("ValidateExternalURL", func() { + It("allows valid external HTTPS URLs", func() { + err := ValidateExternalURL("https://example.com/image.png") + Expect(err).To(BeNil()) + }) + + It("allows valid external HTTP URLs", func() { + err := ValidateExternalURL("http://example.com/image.png") + Expect(err).To(BeNil()) + }) + + It("blocks localhost", func() { + err := ValidateExternalURL("http://localhost/secret") + Expect(err).ToNot(BeNil()) + Expect(err.Error()).To(ContainSubstring("internal")) + }) + + It("blocks 127.0.0.1", func() { + err := ValidateExternalURL("http://127.0.0.1/secret") + Expect(err).ToNot(BeNil()) + Expect(err.Error()).To(ContainSubstring("internal")) + }) + + It("blocks private 10.x.x.x range", func() { + err := ValidateExternalURL("http://10.0.0.1/secret") + Expect(err).ToNot(BeNil()) + Expect(err.Error()).To(ContainSubstring("internal")) + }) + + It("blocks private 172.16.x.x range", func() { + err := ValidateExternalURL("http://172.16.0.1/secret") + Expect(err).ToNot(BeNil()) + Expect(err.Error()).To(ContainSubstring("internal")) + }) + + It("blocks private 192.168.x.x range", func() { + err := ValidateExternalURL("http://192.168.1.1/secret") + Expect(err).ToNot(BeNil()) + Expect(err.Error()).To(ContainSubstring("internal")) + }) + + It("blocks link-local 169.254.x.x (AWS metadata)", func() { + err := ValidateExternalURL("http://169.254.169.254/latest/meta-data/") + Expect(err).ToNot(BeNil()) + Expect(err.Error()).To(ContainSubstring("internal")) + }) + + It("blocks unsupported schemes", func() { + err := ValidateExternalURL("ftp://example.com/file") + Expect(err).ToNot(BeNil()) + Expect(err.Error()).To(ContainSubstring("unsupported URL scheme")) + }) + + It("blocks file:// scheme", func() { + err := ValidateExternalURL("file:///etc/passwd") + Expect(err).ToNot(BeNil()) + Expect(err.Error()).To(ContainSubstring("unsupported URL scheme")) + }) + + It("blocks URLs with no hostname", func() { + err := ValidateExternalURL("http:///path") + Expect(err).ToNot(BeNil()) + Expect(err.Error()).To(ContainSubstring("no hostname")) + }) + + It("blocks .local hostnames", func() { + err := ValidateExternalURL("http://myservice.local/api") + Expect(err).ToNot(BeNil()) + Expect(err.Error()).To(ContainSubstring("internal")) + }) + + It("blocks metadata.google.internal", func() { + err := ValidateExternalURL("http://metadata.google.internal/computeMetadata/v1/") + Expect(err).ToNot(BeNil()) + Expect(err.Error()).To(ContainSubstring("metadata")) + }) + + It("blocks 0.0.0.0", func() { + err := ValidateExternalURL("http://0.0.0.0/") + Expect(err).ToNot(BeNil()) + Expect(err.Error()).To(ContainSubstring("internal")) + }) + + It("blocks IPv6 loopback ::1", func() { + err := ValidateExternalURL("http://[::1]/secret") + Expect(err).ToNot(BeNil()) + Expect(err.Error()).To(ContainSubstring("internal")) + }) + }) +})