From d7a80dda1e19416d5cda562650f28039986e2ab4 Mon Sep 17 00:00:00 2001 From: Shiv Tyagi Date: Thu, 12 Feb 2026 12:11:47 +0000 Subject: [PATCH] Discover vendor from cdi spec before injecting CDI device for --gpu option Signed-off-by: Shiv Tyagi --- docs/source/markdown/options/gpus.md | 3 +- libpod/container_config.go | 2 + libpod/container_internal_common.go | 21 +++- libpod/options.go | 11 ++ libpod/util.go | 53 +++++++++ libpod/util_test.go | 145 +++++++++++++++++++++++ pkg/specgen/generate/container_create.go | 3 + pkg/specgen/specgen.go | 4 + pkg/specgenutil/specgen.go | 7 +- 9 files changed, 240 insertions(+), 9 deletions(-) diff --git a/docs/source/markdown/options/gpus.md b/docs/source/markdown/options/gpus.md index 4c279e3880..b5673a2f4b 100644 --- a/docs/source/markdown/options/gpus.md +++ b/docs/source/markdown/options/gpus.md @@ -4,5 +4,4 @@ ####> are applicable to all of those. #### **--gpus**=*ENTRY* -GPU devices to add to the container ('all' to pass all GPUs) Currently only -Nvidia devices are supported. +Start the container with GPU support. `ENTRY` can be `all` to request all GPUs, or a vendor-specific identifier. Currently, NVIDIA and AMD devices are supported. If both NVIDIA and AMD devices are present, the NVIDIA devices are preferred; to request GPUs from a *specific* vendor, specify a CDI device name using the `--device` flag instead. 
diff --git a/libpod/container_config.go b/libpod/container_config.go index 850589092c..05d5715309 100644 --- a/libpod/container_config.go +++ b/libpod/container_config.go @@ -442,6 +442,8 @@ type ContainerMiscConfig struct { PidFile string `json:"pid_file,omitempty"` // CDIDevices contains devices that use the CDI CDIDevices []string `json:"cdiDevices,omitempty"` + // GPUs contains gpus which eventually get resolved to CDI devices + GPUs []string `json:"gpus,omitempty"` // DeviceHostSrc contains the original source on the host DeviceHostSrc []spec.LinuxDevice `json:"device_host_src,omitempty"` // EnvSecrets are secrets that are set as environment variables diff --git a/libpod/container_internal_common.go b/libpod/container_internal_common.go index af1542b829..0d88d9ec1d 100644 --- a/libpod/container_internal_common.go +++ b/libpod/container_internal_common.go @@ -712,7 +712,8 @@ func (c *Container) generateSpec(ctx context.Context) (s *spec.Spec, cleanupFunc } // Warning: CDI may alter g.Config in place. - if len(c.config.CDIDevices) > 0 { + // GPUs are also handled via CDI. 
+ if len(c.config.CDIDevices) > 0 || len(c.config.GPUs) > 0 { registry, err := cdi.NewCache( cdi.WithSpecDirs(c.runtime.config.Engine.CdiSpecDirs.Get()...), cdi.WithAutoRefresh(false), @@ -723,7 +724,12 @@ func (c *Container) generateSpec(ctx context.Context) (s *spec.Spec, cleanupFunc if err := registry.Refresh(); err != nil { logrus.Debugf("The following error was triggered when refreshing the CDI registry: %v", err) } - if _, err := registry.InjectDevices(g.Config, c.config.CDIDevices...); err != nil { + + cdiDevices, err := getAllCDIDeviceNames(registry, c.config) + if err != nil { + return nil, nil, fmt.Errorf("getting CDI device names: %w", err) + } + if _, err := registry.InjectDevices(g.Config, cdiDevices...); err != nil { return nil, nil, fmt.Errorf("setting up CDI devices: %w", err) } } @@ -3166,3 +3172,14 @@ func maybeClampOOMScoreAdj(oomScoreValue int) (int, error) { } return oomScoreValue, nil } + +func getAllCDIDeviceNames(registry *cdi.Cache, c *ContainerConfig) ([]string, error) { + if len(c.GPUs) == 0 { + return c.CDIDevices, nil + } + gpuCDIDevices, err := gpusToCDIDevices(c.GPUs, registry) + if err != nil { + return nil, fmt.Errorf("converting GPU identifiers to CDI devices: %w", err) + } + return slices.Concat(c.CDIDevices, gpuCDIDevices), nil +} diff --git a/libpod/options.go b/libpod/options.go index badfe6b8c8..c4abfe4afe 100644 --- a/libpod/options.go +++ b/libpod/options.go @@ -270,6 +270,17 @@ func WithCDI(devices []string) CtrCreateOption { } } +// WithGPUs sets the GPU identifiers that are resolved to CDI devices when the container spec is generated. 
+func WithGPUs(gpus []string) CtrCreateOption { + return func(ctr *Container) error { + if ctr.valid { + return define.ErrCtrFinalized + } + ctr.config.GPUs = gpus + return nil + } +} + func WithCDISpecDirs(cdiSpecDirs []string) RuntimeOption { return func(rt *Runtime) error { if rt.valid { diff --git a/libpod/util.go b/libpod/util.go index 3408e0a326..f29b7c249e 100644 --- a/libpod/util.go +++ b/libpod/util.go @@ -292,3 +292,56 @@ func isDirectory(path string) bool { } return info.IsDir() } + +// vendorLister is an interface for listing the vendors known to the CDI cache. +type vendorLister interface { + ListVendors() []string +} + +// discoverGPUVendorFromCDI picks the GPU vendor to use from the vendors the lister reports. +// It returns the vendor domain ("nvidia.com" or "amd.com") used to construct fully +// qualified CDI device names; when both are present, "nvidia.com" is preferred. +// Returns an error if the lister is nil or no known GPU vendor is found. +func discoverGPUVendorFromCDI(lister vendorLister) (string, error) { + if lister == nil { + return "", fmt.Errorf("vendor lister cannot be nil") + } + + knownGPUVendors := []string{ + "nvidia.com", + "amd.com", + } + vendors := lister.ListVendors() + // Iterate known vendors in the outer loop so that their order decides precedence + for _, knownVendor := range knownGPUVendors { + for _, vendor := range vendors { + if vendor == knownVendor { + logrus.Debugf("Discovered GPU vendor from CDI specs: %s", vendor) + return vendor, nil + } + } + } + + return "", fmt.Errorf("no known GPU vendor found in CDI specs") +} + +// gpusToCDIDevices converts GPU identifiers to full CDI device names +// by discovering the vendor from the provided vendor lister. 
+func gpusToCDIDevices(gpus []string, lister vendorLister) ([]string, error) { + if len(gpus) == 0 { + return nil, nil + } + + vendor, err := discoverGPUVendorFromCDI(lister) + if err != nil { + return nil, fmt.Errorf("could not discover GPU vendor: %w", err) + } + + cdiDevices := make([]string, 0, len(gpus)) + for _, gpu := range gpus { + device := fmt.Sprintf("%s/gpu=%s", vendor, gpu) + cdiDevices = append(cdiDevices, device) + logrus.Debugf("Added GPU device: %s", device) + } + return cdiDevices, nil +} diff --git a/libpod/util_test.go b/libpod/util_test.go index 2ca5b36e39..7932d034bf 100644 --- a/libpod/util_test.go +++ b/libpod/util_test.go @@ -67,3 +67,148 @@ func Test_sortMounts(t *testing.T) { }) } } + +type mockVendorLister struct { + vendors []string +} + +func (m *mockVendorLister) ListVendors() []string { + return m.vendors +} + +func Test_gpusToCDIDevices(t *testing.T) { + tests := []struct { + name string + gpus []string + vendors []string + expectError bool + expectDevices []string + }{ + { + name: "No GPUs", + gpus: []string{}, + vendors: []string{"amd.com"}, + }, + { + name: "Nil GPUs", + gpus: nil, + vendors: []string{"amd.com"}, + }, + { + name: "Nil vendors", + gpus: []string{"0"}, + vendors: nil, + expectError: true, + }, + { + name: "Single GPU with AMD", + gpus: []string{"0"}, + vendors: []string{"amd.com"}, + expectDevices: []string{"amd.com/gpu=0"}, + }, + { + name: "Multiple GPUs with AMD", + gpus: []string{"0", "1"}, + vendors: []string{"amd.com"}, + expectDevices: []string{"amd.com/gpu=0", "amd.com/gpu=1"}, + }, + { + name: "Single GPU with NVIDIA", + gpus: []string{"0"}, + vendors: []string{"nvidia.com"}, + expectDevices: []string{"nvidia.com/gpu=0"}, + }, + { + name: "Multiple GPUs with NVIDIA", + gpus: []string{"0", "1"}, + vendors: []string{"nvidia.com"}, + expectDevices: []string{"nvidia.com/gpu=0", "nvidia.com/gpu=1"}, + }, + { + name: "No vendors", + gpus: []string{"0"}, + vendors: []string{}, + expectError: true, + }, + { + 
name: "Unknown vendor", + gpus: []string{"0"}, + vendors: []string{"unknown.com"}, + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var lister vendorLister + if tt.vendors != nil { + lister = &mockVendorLister{vendors: tt.vendors} + } + cdiDevices, err := gpusToCDIDevices(tt.gpus, lister) + + if tt.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.expectDevices, cdiDevices) + } + }) + } +} + +func Test_discoverGPUVendorFromCDI(t *testing.T) { + tests := []struct { + name string + vendors []string + expectVendor string + expectError bool + }{ + { + name: "Nil vendors", + vendors: nil, + expectError: true, + }, + { + name: "NVIDIA vendor", + vendors: []string{"nvidia.com"}, + expectVendor: "nvidia.com", + }, + { + name: "AMD vendor", + vendors: []string{"amd.com"}, + expectVendor: "amd.com", + }, + { + name: "No vendors", + vendors: []string{}, + expectError: true, + }, + { + name: "Unknown vendor", + vendors: []string{"unknown.com"}, + expectError: true, + }, + { + name: "Mixed vendor", + vendors: []string{"amd.com", "nvidia.com"}, + expectVendor: "nvidia.com", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var lister vendorLister + if tt.vendors != nil { + lister = &mockVendorLister{vendors: tt.vendors} + } + vendor, err := discoverGPUVendorFromCDI(lister) + + if tt.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.expectVendor, vendor) + } + }) + } +} diff --git a/pkg/specgen/generate/container_create.go b/pkg/specgen/generate/container_create.go index d69b0f3f44..cfcc16e302 100644 --- a/pkg/specgen/generate/container_create.go +++ b/pkg/specgen/generate/container_create.go @@ -246,6 +246,9 @@ func MakeContainer(ctx context.Context, rt *libpod.Runtime, s *specgen.SpecGener logrus.Debugf("setting container name %s", s.Name) options = append(options, libpod.WithName(s.Name)) } + if len(s.GPUs) > 
0 { + options = append(options, libpod.WithGPUs(s.GPUs)) + } if len(s.Devices) > 0 { opts = ExtractCDIDevices(s) options = append(options, opts...) diff --git a/pkg/specgen/specgen.go b/pkg/specgen/specgen.go index 767710e898..4b64f679fa 100644 --- a/pkg/specgen/specgen.go +++ b/pkg/specgen/specgen.go @@ -318,6 +318,10 @@ type ContainerStorageConfig struct { DevicesFrom []string `json:"devices_from,omitempty"` // HostDeviceList is used to recreate the mounted device on inherited containers HostDeviceList []spec.LinuxDevice `json:"host_device_list,omitempty"` + // GPUs contains GPU device identifiers for CDI resolution. + // These will be resolved to fully qualified CDI device names on the server side. + // Optional. + GPUs []string `json:"gpus,omitempty"` // IpcNS is the container's IPC namespace. // Default is private. // Conflicts with ShmSize if not set to private. diff --git a/pkg/specgenutil/specgen.go b/pkg/specgenutil/specgen.go index 2456a9555e..f3ec108e99 100644 --- a/pkg/specgenutil/specgen.go +++ b/pkg/specgenutil/specgen.go @@ -812,12 +812,9 @@ func FillOutSpecGen(s *specgen.SpecGenerator, c *entities.ContainerCreateOptions s.ArtifactVolumes = containerMounts.artifactVolumes } - devices := c.Devices - for _, gpu := range c.GPUs { - devices = append(devices, "nvidia.com/gpu="+gpu) - } + s.GPUs = c.GPUs - for _, dev := range devices { + for _, dev := range c.Devices { s.Devices = append(s.Devices, specs.LinuxDevice{Path: dev}) }