diff --git a/pkg/system/capabilities.go b/pkg/system/capabilities.go index bbff3b8e2..9141d5929 100644 --- a/pkg/system/capabilities.go +++ b/pkg/system/capabilities.go @@ -132,29 +132,32 @@ func (s *SystemState) getSystemCapabilities() string { } } - if cuda13DirExists { - s.systemCapabilities = nvidiaCuda13 - return s.systemCapabilities - } - - if cuda12DirExists { - s.systemCapabilities = nvidiaCuda12 - return s.systemCapabilities - } - + // No GPU detected → default capability if s.GPUVendor == "" { xlog.Info("Default capability (no GPU detected)", "env", capabilityEnv) s.systemCapabilities = defaultCapability return s.systemCapabilities } - // If vram is less than 4GB, let's default to CPU but warn the user that they can override that via env + // GPU detected but insufficient VRAM → default with warning if s.VRAM <= 4*1024*1024*1024 { xlog.Warn("VRAM is less than 4GB, defaulting to CPU", "env", capabilityEnv) s.systemCapabilities = defaultCapability return s.systemCapabilities } + // CUDA directories refine capability only for NVIDIA GPUs + if s.GPUVendor == Nvidia { + if cuda13DirExists { + s.systemCapabilities = nvidiaCuda13 + return s.systemCapabilities + } + if cuda12DirExists { + s.systemCapabilities = nvidiaCuda12 + return s.systemCapabilities + } + } + s.systemCapabilities = s.GPUVendor return s.systemCapabilities } diff --git a/pkg/system/capabilities_test.go b/pkg/system/capabilities_test.go new file mode 100644 index 000000000..a267cf611 --- /dev/null +++ b/pkg/system/capabilities_test.go @@ -0,0 +1,129 @@ +package system + +import ( + "os" + "runtime" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var _ = Describe("getSystemCapabilities", func() { + const eightGB = 8 * 1024 * 1024 * 1024 + const twoGB = 2 * 1024 * 1024 * 1024 + + var ( + origEnv string + origCuda12 bool + origCuda13 bool + ) + + BeforeEach(func() { + if runtime.GOOS == "darwin" { + Skip("darwin short-circuits before reaching CUDA logic") + } + + origEnv = os.Getenv(capabilityEnv) + os.Unsetenv(capabilityEnv) + + origCuda12 = cuda12DirExists + origCuda13 = cuda13DirExists + }) + + AfterEach(func() { + cuda12DirExists = origCuda12 + cuda13DirExists = origCuda13 + + if origEnv != "" { + os.Setenv(capabilityEnv, origEnv) + } + }) + + type testCase struct { + gpuVendor string + vram uint64 + cuda12 bool + cuda13 bool + wantCapability string + wantTokens []string + } + + DescribeTable("capability detection", + func(tc testCase) { + cuda12DirExists = tc.cuda12 + cuda13DirExists = tc.cuda13 + + s := &SystemState{ + GPUVendor: tc.gpuVendor, + VRAM: tc.vram, + } + + Expect(s.getSystemCapabilities()).To(Equal(tc.wantCapability)) + Expect(s.BackendPreferenceTokens()).To(Equal(tc.wantTokens)) + }, + Entry("CUDA dir present but no GPU", testCase{ + gpuVendor: "", + vram: 0, + cuda12: true, + cuda13: false, + wantCapability: "default", + wantTokens: []string{"cpu"}, + }), + Entry("CUDA 12 with NVIDIA GPU", testCase{ + gpuVendor: Nvidia, + vram: eightGB, + cuda12: true, + cuda13: false, + wantCapability: "nvidia-cuda-12", + wantTokens: []string{"cuda", "vulkan", "cpu"}, + }), + Entry("CUDA 13 with NVIDIA GPU", testCase{ + gpuVendor: Nvidia, + vram: eightGB, + cuda12: false, + cuda13: true, + wantCapability: "nvidia-cuda-13", + wantTokens: []string{"cuda", "vulkan", "cpu"}, + }), + Entry("Both CUDA dirs with NVIDIA GPU prefers 13", testCase{ + gpuVendor: Nvidia, + vram: eightGB, + cuda12: true, + cuda13: true, + wantCapability: "nvidia-cuda-13", + wantTokens: []string{"cuda", "vulkan", "cpu"}, + }), + Entry("CUDA dir with AMD GPU ignored", testCase{ + gpuVendor: AMD, + vram: eightGB, + cuda12: true, + cuda13: false, + wantCapability: "amd", + wantTokens: []string{"rocm", "hip", "vulkan", "cpu"}, + }), + Entry("No CUDA dir and no GPU", testCase{ + gpuVendor: "", + vram: 0, + cuda12: false, + cuda13: false, + wantCapability: "default", + wantTokens: []string{"cpu"}, + }), + Entry("No CUDA dir with NVIDIA GPU", testCase{ + gpuVendor: Nvidia, + vram: eightGB, + cuda12: false, + cuda13: false, + wantCapability: "nvidia", + wantTokens: []string{"cuda", "vulkan", "cpu"}, + }), + Entry("CUDA dir with NVIDIA GPU but low VRAM", testCase{ + gpuVendor: Nvidia, + vram: twoGB, + cuda12: true, + cuda13: false, + wantCapability: "default", + wantTokens: []string{"cpu"}, + }), + ) +}) diff --git a/pkg/system/system_suite_test.go b/pkg/system/system_suite_test.go new file mode 100644 index 000000000..9c9fded36 --- /dev/null +++ b/pkg/system/system_suite_test.go @@ -0,0 +1,13 @@ +package system + +import ( + "testing" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +func TestSystem(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "System test suite") +}