Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
The table of contents is too big for display.
Diff view
Diff view
  •  
  •  
  •  
167 changes: 167 additions & 0 deletions cmd/gpu-mockctl/commands/cdi.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
// Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package commands

import (
"context"
"fmt"
"os"
"path/filepath"

"github.com/urfave/cli/v3"

gpuconfig "github.com/NVIDIA/k8s-test-infra/cmd/gpu-mockctl/config"
"github.com/NVIDIA/k8s-test-infra/cmd/gpu-mockctl/internal/logger"
"github.com/NVIDIA/k8s-test-infra/pkg/gpu/cdi"
"github.com/NVIDIA/k8s-test-infra/pkg/gpu/mockdriver"
"github.com/NVIDIA/k8s-test-infra/pkg/gpu/mocktopo"
)

// NewCDICommand creates the 'cdi' subcommand
func NewCDICommand(cfg *gpuconfig.Config) *cli.Command {
return &cli.Command{
Name: "cdi",
Usage: "Generate mock driver tree and CDI specification",
Flags: []cli.Flag{
&cli.StringFlag{
Name: "driver-root",
Usage: "host mock driver tree root",
Value: cfg.DriverRoot,
Destination: &cfg.DriverRoot,
},
&cli.StringFlag{
Name: "cdi-output",
Usage: "CDI spec output path",
Value: cfg.CDIOutput,
Destination: &cfg.CDIOutput,
},
&cli.BoolFlag{
Name: "with-dri",
Usage: "include DRI render node",
Destination: &cfg.WithDRI,
},
&cli.BoolFlag{
Name: "with-hook",
Usage: "include CDI hook references",
Destination: &cfg.WithHook,
},
&cli.StringFlag{
Name: "toolkit-root",
Usage: "toolkit root for hook paths",
Value: cfg.ToolkitRoot,
Destination: &cfg.ToolkitRoot,
},
},
Before: func(ctx context.Context, cmd *cli.Command) (context.Context, error) {
// Validate configuration
if err := cfg.ValidateCDI(); err != nil {
return ctx, err
}
return ctx, nil
},
Action: func(ctx context.Context, cmd *cli.Command) error {
log := getLogger(cmd)
return runCDI(cfg, log)
},
}
}

func runCDI(cfg *gpuconfig.Config, log logger.Interface) error {
log.Infof("Generating CDI specification for machine: %s", cfg.Machine)
log.Debugf("Driver root: %s", cfg.DriverRoot)
log.Debugf("CDI output: %s", cfg.CDIOutput)

// Get topology
topo, err := mocktopo.New(cfg.Machine)
if err != nil {
if os.Getenv("ALLOW_UNSUPPORTED") == "true" {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I want to understand why this is useful? Does having this option not complicate the implementation? Why not just default to the one machine type that we have and figure out how to expose new ones as they become available?

log.Warningf("Using fallback mock for CDI generation")
topo = mocktopo.NewFallback(8, "NVIDIA A100-SXM4-40GB")
} else {
return fmt.Errorf("failed to create topology: %w", err)
}
}

gpuCount := len(topo.GPUs)
log.Debugf("GPU count: %d", gpuCount)

// Create mock driver tree
log.Debugf("Writing mock driver files to %s", cfg.DriverRoot)
files := mockdriver.DefaultFiles(cfg.DriverRoot)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These files should be INTERNAL to the driver.

if err := mockdriver.WriteAll(files); err != nil {
return fmt.Errorf("failed to write driver files: %w", err)
}
log.Infof("Mock driver tree written to %s", cfg.DriverRoot)

// Create device nodes
if err := createDeviceNodes(cfg, gpuCount, log); err != nil {
// Log warnings but don't fail - device nodes might already exist
log.Warningf("Device node creation had errors: %v", err)
}

// Generate CDI spec using nvidia-container-toolkit nvcdi library
log.Debugf("Generating CDI specification")
cdiOpts := cdi.Options{
NVMLLib: topo.NVMLInterface(),
DriverRoot: cfg.DriverRoot,
DevRoot: "/host/dev", // DevRoot is already prefixed by the DaemonSet mount
NVIDIACDIHookPath: filepath.Join(cfg.ToolkitRoot, "bin/nvidia-cdi-hook"),
}

specYAML, err := cdi.Generate(cdiOpts)
if err != nil {
return fmt.Errorf("failed to generate CDI spec: %w", err)
}

// Validate before writing
log.Debugf("Validating CDI specification")
if err := cdi.Validate(specYAML); err != nil {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The SPEC is validated on save. We should not be validating this on our own -- this should either be done in the nvcdi package or upstream.

return fmt.Errorf("CDI spec validation failed: %w", err)
}

// Write CDI spec
log.Debugf("Writing CDI spec to %s", cfg.CDIOutput)
if err := os.MkdirAll(filepath.Dir(cfg.CDIOutput), 0o755); err != nil {
return fmt.Errorf("failed to create CDI directory: %w", err)
}

if err := os.WriteFile(cfg.CDIOutput, specYAML, 0o644); err != nil {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't believe that this is how we output the CDI spec in the tooling that we release. See https://github.com/NVIDIA/nvidia-container-toolkit/blob/5bb032d60486da9b441a208f225f911efbad35f2/cmd/nvidia-ctk/cdi/generate/generate.go#L332

return fmt.Errorf("failed to write CDI spec: %w", err)
}

log.Infof("CDI spec written to %s (generated via nvidia-container-toolkit)",
cfg.CDIOutput)
return nil
}

func createDeviceNodes(cfg *gpuconfig.Config, gpuCount int, log logger.Interface) error {
// Create device nodes (both host /dev and under driverRoot/dev)
// Host /dev nodes for CDI runtime compatibility
log.Debugf("Creating host device nodes in /dev")
hostDevNodes := mockdriver.DeviceNodes("/dev", gpuCount, cfg.WithDRI)
if err := mockdriver.WriteAll(hostDevNodes); err != nil {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why does the mockdriver not have a CreateDeviceNodes function instead of having to keep track of the mapping here?

log.Warningf("Failed to create host /dev nodes: %v", err)
// Don't return error, continue with driver root nodes
}

// Also create under driverRoot/dev for completeness
log.Debugf("Creating device nodes under %s/dev", cfg.DriverRoot)
driverDevNodes := mockdriver.DeviceNodes(cfg.DriverRoot, gpuCount, cfg.WithDRI)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we create these under two locations?

if err := mockdriver.WriteAll(driverDevNodes); err != nil {
log.Warningf("Failed to create %s/dev nodes: %v", cfg.DriverRoot, err)
// Don't return error as nodes might already exist
}

return nil
}
217 changes: 217 additions & 0 deletions cmd/gpu-mockctl/commands/cdi_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,217 @@
// Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package commands

import (
"bytes"
"context"
"log"
"os"
"path/filepath"
"strings"
"testing"

gpuconfig "github.com/NVIDIA/k8s-test-infra/cmd/gpu-mockctl/config"
"github.com/NVIDIA/k8s-test-infra/cmd/gpu-mockctl/internal/logger"
)

func TestNewCDICommand(t *testing.T) {
cfg := gpuconfig.NewDefault()
cmd := NewCDICommand(cfg)

// Test basic properties
if cmd.Name != "cdi" {
t.Errorf("Expected name=cdi, got %s", cmd.Name)
}

// Test flags
expectedFlags := map[string]bool{
"driver-root": false,
"cdi-output": false,
"with-dri": false,
"with-hook": false,
"toolkit-root": false,
}

for _, f := range cmd.Flags {
name := f.Names()[0]
if _, expected := expectedFlags[name]; expected {
expectedFlags[name] = true
}
}

for flag, found := range expectedFlags {
if !found {
t.Errorf("Expected flag %s not found", flag)
}
}
}

func TestRunCDI(t *testing.T) {
// Note: This test may fail with "operation not permitted" errors
// when trying to create device nodes with mknod.
// This is expected behavior when not running as root.

// Skip this test if not running as root
if os.Getuid() != 0 {
t.Skip("Skipping test that requires root privileges for mknod")
}

// Create a temporary directory for testing
tmpDir, err := os.MkdirTemp("", "gpu-mockctl-cdi-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer func() {
if err := os.RemoveAll(tmpDir); err != nil {
t.Logf("Failed to remove temp dir: %v", err)
}
}()

// Capture log output
var buf bytes.Buffer
log.SetOutput(&buf)
defer log.SetOutput(os.Stderr)

tests := []struct {
name string
cfg *gpuconfig.Config
env map[string]string
wantErr bool
wantLog []string
}{
{
name: "valid dgxa100",
cfg: &gpuconfig.Config{
DriverRoot: filepath.Join(tmpDir, "driver"),
CDIOutput: filepath.Join(tmpDir, "cdi/nvidia.yaml"),
Machine: "dgxa100",
ToolkitRoot: "/usr/local/nvidia-container-toolkit",
},
wantErr: false,
wantLog: []string{
"Generating CDI specification for machine: dgxa100",
"Mock driver tree written to",
"CDI spec written to",
},
},
{
name: "with DRI nodes",
cfg: &gpuconfig.Config{
DriverRoot: filepath.Join(tmpDir, "driver-dri"),
CDIOutput: filepath.Join(tmpDir, "cdi-dri/nvidia.yaml"),
Machine: "dgxa100",
WithDRI: true,
ToolkitRoot: "/usr/local/nvidia-container-toolkit",
},
wantErr: false,
wantLog: []string{
"Generating CDI specification for machine: dgxa100",
},
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Set environment variables
for k, v := range tt.env {
if err := os.Setenv(k, v); err != nil {
t.Fatalf("Failed to set env %s: %v", k, err)
}
defer func(key string) {
if err := os.Unsetenv(key); err != nil {
t.Logf("Failed to unset env %s: %v", key, err)
}
}(k)
}

buf.Reset()
testLogger := logger.New("test", false)
err := runCDI(tt.cfg, testLogger)

if (err != nil) != tt.wantErr {
t.Errorf("runCDI() error = %v, wantErr %v", err, tt.wantErr)
}

// Check log output
logOutput := buf.String()
for _, want := range tt.wantLog {
if !strings.Contains(logOutput, want) {
t.Errorf("Expected log to contain %q, got:\n%s", want, logOutput)
}
}

// If successful, check that files were created
if !tt.wantErr && err == nil {
// Check that CDI spec was created
if _, err := os.Stat(tt.cfg.CDIOutput); os.IsNotExist(err) {
t.Errorf("Expected CDI spec at %s to exist", tt.cfg.CDIOutput)
}
}
})
}
}

func TestCDICommandValidation(t *testing.T) {
tests := []struct {
name string
cfg *gpuconfig.Config
wantErr bool
}{
{
name: "valid config",
cfg: &gpuconfig.Config{
DriverRoot: "/var/lib/nvidia-mock/driver",
CDIOutput: "/etc/cdi/nvidia.yaml",
ToolkitRoot: "/usr/local/nvidia-container-toolkit",
Machine: "dgxa100",
},
wantErr: false,
},
{
name: "empty driver root",
cfg: &gpuconfig.Config{
DriverRoot: "",
CDIOutput: "/etc/cdi/nvidia.yaml",
ToolkitRoot: "/usr/local/nvidia-container-toolkit",
Machine: "dgxa100",
},
wantErr: true,
},
{
name: "empty cdi output",
cfg: &gpuconfig.Config{
DriverRoot: "/var/lib/nvidia-mock/driver",
CDIOutput: "",
ToolkitRoot: "/usr/local/nvidia-container-toolkit",
Machine: "dgxa100",
},
wantErr: true,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cmd := NewCDICommand(tt.cfg)

// The Before hook should validate
ctx := context.Background()
_, err := cmd.Before(ctx, cmd)

if (err != nil) != tt.wantErr {
t.Errorf("Before() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
Loading
Loading