Merge pull request #28008 from shiv-tyagi/vendor-detection

Discover GPU vendor from CDI spec before injecting GPU for --gpus option
This commit is contained in:
Paul Holzinger
2026-02-12 18:54:34 +01:00
committed by GitHub
9 changed files with 240 additions and 9 deletions

View File

@@ -4,5 +4,4 @@
####> are applicable to all of those.
#### **--gpus**=*ENTRY*
GPU devices to add to the container ('all' to pass all GPUs) Currently only
Nvidia devices are supported.
Start the container with GPU support. Where `ENTRY` can be `all` to request all GPUs, or a vendor-specific identifier. Currently, NVIDIA and AMD devices are supported. If both NVIDIA and AMD devices are present, the NVIDIA devices will be preferred, and a CDI device name must be specified using the `--device` flag to request a set of GPUs from a *specific* vendor.

View File

@@ -442,6 +442,8 @@ type ContainerMiscConfig struct {
PidFile string `json:"pid_file,omitempty"`
// CDIDevices contains devices that use the CDI
CDIDevices []string `json:"cdiDevices,omitempty"`
// GPUs contains gpus which eventually get resolved to CDI devices
GPUs []string `json:"gpus,omitempty"`
// DeviceHostSrc contains the original source on the host
DeviceHostSrc []spec.LinuxDevice `json:"device_host_src,omitempty"`
// EnvSecrets are secrets that are set as environment variables

View File

@@ -712,7 +712,8 @@ func (c *Container) generateSpec(ctx context.Context) (s *spec.Spec, cleanupFunc
}
// Warning: CDI may alter g.Config in place.
if len(c.config.CDIDevices) > 0 {
// GPUs are also handled via CDI.
if len(c.config.CDIDevices) > 0 || len(c.config.GPUs) > 0 {
registry, err := cdi.NewCache(
cdi.WithSpecDirs(c.runtime.config.Engine.CdiSpecDirs.Get()...),
cdi.WithAutoRefresh(false),
@@ -723,7 +724,12 @@ func (c *Container) generateSpec(ctx context.Context) (s *spec.Spec, cleanupFunc
if err := registry.Refresh(); err != nil {
logrus.Debugf("The following error was triggered when refreshing the CDI registry: %v", err)
}
if _, err := registry.InjectDevices(g.Config, c.config.CDIDevices...); err != nil {
cdiDevices, err := getAllCDIDeviceNames(registry, c.config)
if err != nil {
return nil, nil, fmt.Errorf("getting CDI device names: %w", err)
}
if _, err := registry.InjectDevices(g.Config, cdiDevices...); err != nil {
return nil, nil, fmt.Errorf("setting up CDI devices: %w", err)
}
}
@@ -3167,3 +3173,14 @@ func maybeClampOOMScoreAdj(oomScoreValue int) (int, error) {
}
return oomScoreValue, nil
}
func getAllCDIDeviceNames(registry *cdi.Cache, c *ContainerConfig) ([]string, error) {
if len(c.GPUs) == 0 {
return c.CDIDevices, nil
}
gpuCDIDevices, err := gpusToCDIDevices(c.GPUs, registry)
if err != nil {
return nil, fmt.Errorf("converting GPU identifiers to CDI devices: %w", err)
}
return slices.Concat(c.CDIDevices, gpuCDIDevices), nil
}

View File

@@ -270,6 +270,17 @@ func WithCDI(devices []string) CtrCreateOption {
}
}
// WithGPUs sets the devices to check for CDI configuration.
func WithGPUs(gpus []string) CtrCreateOption {
return func(ctr *Container) error {
if ctr.valid {
return define.ErrCtrFinalized
}
ctr.config.GPUs = gpus
return nil
}
}
func WithCDISpecDirs(cdiSpecDirs []string) RuntimeOption {
return func(rt *Runtime) error {
if rt.valid {

View File

@@ -292,3 +292,56 @@ func isDirectory(path string) bool {
}
return info.IsDir()
}
// vendorLister is an interface for listing GPU vendors from CDI.
type vendorLister interface {
ListVendors() []string
}
// discoverGPUVendorFromCDI discovers vendor from CDI cache.
// It returns the vendor domain (e.g., "nvidia.com", "amd.com") that should
// be used to construct fully qualified CDI device names.
// Returns an error if no known GPU vendor is found.
func discoverGPUVendorFromCDI(lister vendorLister) (string, error) {
if lister == nil {
return "", fmt.Errorf("vendor lister cannot be nil")
}
knownGPUVendors := []string{
"nvidia.com",
"amd.com",
}
vendors := lister.ListVendors()
// Check if any known GPU vendor is present
for _, knownVendor := range knownGPUVendors {
for _, vendor := range vendors {
if vendor == knownVendor {
logrus.Debugf("Discovered GPU vendor from CDI specs: %s", vendor)
return vendor, nil
}
}
}
return "", fmt.Errorf("no known GPU vendor found in CDI specs")
}
// gpusToCDIDevices converts GPU identifiers to full CDI device names
// by discovering the vendor from the provided vendor lister.
func gpusToCDIDevices(gpus []string, lister vendorLister) ([]string, error) {
if len(gpus) == 0 {
return nil, nil
}
vendor, err := discoverGPUVendorFromCDI(lister)
if err != nil {
return nil, fmt.Errorf("could not discover GPU vendor: %w", err)
}
cdiDevices := make([]string, 0, len(gpus))
for _, gpu := range gpus {
device := fmt.Sprintf("%s/gpu=%s", vendor, gpu)
cdiDevices = append(cdiDevices, device)
logrus.Debugf("Added GPU device: %s", device)
}
return cdiDevices, nil
}

View File

@@ -67,3 +67,148 @@ func Test_sortMounts(t *testing.T) {
})
}
}
type mockVendorLister struct {
vendors []string
}
func (m *mockVendorLister) ListVendors() []string {
return m.vendors
}
func Test_gpusToCDIDevices(t *testing.T) {
tests := []struct {
name string
gpus []string
vendors []string
expectError bool
expectDevices []string
}{
{
name: "No GPUs",
gpus: []string{},
vendors: []string{"amd.com"},
},
{
name: "Nil GPUs",
gpus: nil,
vendors: []string{"amd.com"},
},
{
name: "Nil vendors",
gpus: []string{"0"},
vendors: nil,
expectError: true,
},
{
name: "Single GPU with AMD",
gpus: []string{"0"},
vendors: []string{"amd.com"},
expectDevices: []string{"amd.com/gpu=0"},
},
{
name: "Multiple GPUs with AMD",
gpus: []string{"0", "1"},
vendors: []string{"amd.com"},
expectDevices: []string{"amd.com/gpu=0", "amd.com/gpu=1"},
},
{
name: "Single GPU with NVIDIA",
gpus: []string{"0"},
vendors: []string{"nvidia.com"},
expectDevices: []string{"nvidia.com/gpu=0"},
},
{
name: "Multiple GPUs with NVIDIA",
gpus: []string{"0", "1"},
vendors: []string{"nvidia.com"},
expectDevices: []string{"nvidia.com/gpu=0", "nvidia.com/gpu=1"},
},
{
name: "No vendors",
gpus: []string{"0"},
vendors: []string{},
expectError: true,
},
{
name: "Unknown vendor",
gpus: []string{"0"},
vendors: []string{"unknown.com"},
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var lister vendorLister
if tt.vendors != nil {
lister = &mockVendorLister{vendors: tt.vendors}
}
cdiDevices, err := gpusToCDIDevices(tt.gpus, lister)
if tt.expectError {
assert.Error(t, err)
} else {
assert.NoError(t, err)
assert.Equal(t, tt.expectDevices, cdiDevices)
}
})
}
}
func Test_discoverGPUVendorFromCDI(t *testing.T) {
tests := []struct {
name string
vendors []string
expectVendor string
expectError bool
}{
{
name: "Nil vendors",
vendors: nil,
expectError: true,
},
{
name: "NVIDIA vendor",
vendors: []string{"nvidia.com"},
expectVendor: "nvidia.com",
},
{
name: "AMD vendor",
vendors: []string{"amd.com"},
expectVendor: "amd.com",
},
{
name: "No vendors",
vendors: []string{},
expectError: true,
},
{
name: "Unknown vendor",
vendors: []string{"unknown.com"},
expectError: true,
},
{
name: "Mixed vendor",
vendors: []string{"amd.com", "nvidia.com"},
expectVendor: "nvidia.com",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var lister vendorLister
if tt.vendors != nil {
lister = &mockVendorLister{vendors: tt.vendors}
}
vendor, err := discoverGPUVendorFromCDI(lister)
if tt.expectError {
assert.Error(t, err)
} else {
assert.NoError(t, err)
assert.Equal(t, tt.expectVendor, vendor)
}
})
}
}

View File

@@ -246,6 +246,9 @@ func MakeContainer(ctx context.Context, rt *libpod.Runtime, s *specgen.SpecGener
logrus.Debugf("setting container name %s", s.Name)
options = append(options, libpod.WithName(s.Name))
}
if len(s.GPUs) > 0 {
options = append(options, libpod.WithGPUs(s.GPUs))
}
if len(s.Devices) > 0 {
opts = ExtractCDIDevices(s)
options = append(options, opts...)

View File

@@ -318,6 +318,10 @@ type ContainerStorageConfig struct {
DevicesFrom []string `json:"devices_from,omitempty"`
// HostDeviceList is used to recreate the mounted device on inherited containers
HostDeviceList []spec.LinuxDevice `json:"host_device_list,omitempty"`
// GPUs contains GPU device identifiers for CDI resolution.
// These will be resolved to full CDI device paths on the server side.
// Optional.
GPUs []string `json:"gpus,omitempty"`
// IpcNS is the container's IPC namespace.
// Default is private.
// Conflicts with ShmSize if not set to private.

View File

@@ -812,12 +812,9 @@ func FillOutSpecGen(s *specgen.SpecGenerator, c *entities.ContainerCreateOptions
s.ArtifactVolumes = containerMounts.artifactVolumes
}
devices := c.Devices
for _, gpu := range c.GPUs {
devices = append(devices, "nvidia.com/gpu="+gpu)
}
s.GPUs = c.GPUs
for _, dev := range devices {
for _, dev := range c.Devices {
s.Devices = append(s.Devices, specs.LinuxDevice{Path: dev})
}