Skip to content

Commit fe47297

Browse files
committed
wip: update vfio-manage to choose best VFIO driver
Signed-off-by: Christopher Desiniotis <[email protected]>
1 parent e5733cc commit fe47297

File tree

2 files changed

+347
-9
lines changed

2 files changed

+347
-9
lines changed

internal/nvpci/modalias.go

Lines changed: 281 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,281 @@
1+
package nvpci
2+
3+
import (
4+
"fmt"
5+
"strings"
6+
7+
"github.com/sirupsen/logrus"
8+
"golang.org/x/sys/unix"
9+
)
10+
11+
// modAlias holds the individual fields of a PCI modalias string such as
//
//	vNNNNNNNNdNNNNNNNNsvNNNNNNNNsdNNNNNNNNbcNNscNNiNN
//
// (followed by a space or newline). Each "NN..." run normally has the width
// shown above, but any of them may instead be the wildcard "*", so no
// assumption is made about the length of a field. Every field is stored as
// the raw text found after its tag.
type modAlias struct {
	// vendor is the value following the "v" tag.
	vendor string
	// device is the value following the "d" tag.
	device string
	// subvendor is the value following the "sv" tag.
	subvendor string
	// subdevice is the value following the "sd" tag.
	subdevice string
	// baseClass is the value following the "bc" tag.
	baseClass string
	// subClass is the value following the "sc" tag.
	subClass string
	// interface_ is the value following the "i" tag.
	interface_ string
}
27+
28+
// vfioAlias represents an entry from modules.alias for a vfio driver
29+
type vfioAlias struct {
30+
modAlias *modAlias // The modalias pattern
31+
driver string // The vfio driver name
32+
}
33+
34+
// parseModAlias parses a modalias string in the format:
35+
// vNNNNNNNNdNNNNNNNNsvNNNNNNNNsdNNNNNNNNbcNNscNNiNN
36+
// where N can be hex digits or wildcards (*)
37+
func parseModAlias(input string) (*modAlias, error) {
38+
if input == "" {
39+
return nil, fmt.Errorf("modalias string is empty")
40+
}
41+
42+
// Trim any whitespace or newlines
43+
input = strings.TrimSpace(input)
44+
45+
// Trim the leading "pci:" prefix in the modalias file
46+
split := strings.SplitN(input, ":", 2)
47+
if len(split) != 2 {
48+
return nil, fmt.Errorf("unexpected number of parts in modalias after trimming pci: prefix: %s", input)
49+
}
50+
input = split[1]
51+
52+
// Parse vendor ID (starts with 'v')
53+
if !strings.HasPrefix(input, "v") {
54+
return nil, fmt.Errorf("modalias must start with 'v', got: %s", input)
55+
}
56+
57+
ma := &modAlias{}
58+
remaining := input[1:] // skip 'v'
59+
60+
// Extract vendor
61+
vendor, remaining, err := extractField(remaining, "d")
62+
if err != nil {
63+
return nil, fmt.Errorf("failed to parse vendor: %w", err)
64+
}
65+
ma.vendor = vendor
66+
67+
// Extract device
68+
device, remaining, err := extractField(remaining, "sv")
69+
if err != nil {
70+
return nil, fmt.Errorf("failed to parse device: %w", err)
71+
}
72+
ma.device = device
73+
74+
// Extract subvendor
75+
subvendor, remaining, err := extractField(remaining, "sd")
76+
if err != nil {
77+
return nil, fmt.Errorf("failed to parse subvendor: %w", err)
78+
}
79+
ma.subvendor = subvendor
80+
81+
// Extract subdevice
82+
subdevice, remaining, err := extractField(remaining, "bc")
83+
if err != nil {
84+
return nil, fmt.Errorf("failed to parse subdevice: %w", err)
85+
}
86+
ma.subdevice = subdevice
87+
88+
// Extract base class
89+
baseClass, remaining, err := extractField(remaining, "sc")
90+
if err != nil {
91+
return nil, fmt.Errorf("failed to parse base class: %w", err)
92+
}
93+
ma.baseClass = baseClass
94+
95+
// Extract subclass
96+
subClass, remaining, err := extractField(remaining, "i")
97+
if err != nil {
98+
return nil, fmt.Errorf("failed to parse subclass: %w", err)
99+
}
100+
ma.subClass = subClass
101+
102+
// Extract interface (remaining content)
103+
ma.interface_ = remaining
104+
105+
return ma, nil
106+
}
107+
108+
// extractField splits input at the first occurrence of delimiter.
//
// On success it returns the text before the delimiter, the text after the
// delimiter (the delimiter itself is dropped), and a nil error.
//
// If delimiter does not occur in input, both returned strings are empty and
// a non-nil error is returned, so callers can detect truncated or malformed
// input instead of silently receiving the whole input as the field value.
func extractField(input, delimiter string) (string, string, error) {
	idx := strings.Index(input, delimiter)
	if idx == -1 {
		return "", "", fmt.Errorf("delimiter %q not found in %q", delimiter, input)
	}

	return input[:idx], input[idx+len(delimiter):], nil
}
122+
123+
func getKernelVersion() (string, error) {
124+
var uname unix.Utsname
125+
if err := unix.Uname(&uname); err != nil {
126+
return "", err
127+
}
128+
129+
// Convert C-style byte array to Go string
130+
release := make([]byte, 0, len(uname.Release))
131+
for _, c := range uname.Release {
132+
if c == 0 {
133+
break
134+
}
135+
release = append(release, byte(c))
136+
}
137+
138+
return string(release), nil
139+
}
140+
141+
// parseVFIOAliases parses the modules.alias file and extracts all entries
142+
// that contain "vfio_pci" in the driver name
143+
func parseVFIOAliases(content string) []vfioAlias {
144+
var aliases []vfioAlias
145+
146+
lines := strings.Split(content, "\n")
147+
for _, line := range lines {
148+
line = strings.TrimSpace(line)
149+
150+
// Skip empty lines and comments
151+
if line == "" || strings.HasPrefix(line, "#") {
152+
continue
153+
}
154+
155+
// Expected format: "alias <pattern> <driver>"
156+
parts := strings.Fields(line)
157+
if len(parts) < 3 {
158+
continue
159+
}
160+
161+
if parts[0] != "alias" {
162+
continue
163+
}
164+
165+
pattern := parts[1]
166+
driver := parts[2]
167+
168+
// Filter for vfio drivers
169+
if strings.HasPrefix(driver, "vfio_pci:") {
170+
modAlias, err := parseModAlias(pattern)
171+
if err != nil {
172+
logrus.Warnf("failed to parse modalias for %s: %v", pattern, err)
173+
continue
174+
}
175+
logrus.Infof("modalias for %s: %+v", pattern, modAlias)
176+
aliases = append(aliases, vfioAlias{
177+
modAlias: modAlias,
178+
driver: driver,
179+
})
180+
}
181+
}
182+
183+
return aliases
184+
}
185+
186+
// findBestMatch finds the best matching VFIO driver for the given modalias
187+
// by comparing against all available vfio alias patterns
188+
func findBestMatch(deviceModAlias *modAlias, aliases []vfioAlias) string {
189+
// Track the best match (most specific pattern)
190+
var bestDriver string
191+
bestSpecificity := -1
192+
193+
for _, alias := range aliases {
194+
if matches, specificity := matchModalias(deviceModAlias, alias.modAlias); matches {
195+
if specificity > bestSpecificity {
196+
bestDriver = alias.driver
197+
bestSpecificity = specificity
198+
}
199+
}
200+
}
201+
202+
return bestDriver
203+
}
204+
205+
// matchModalias checks if a device modalias matches a pattern from modules.alias
206+
// Returns true if it matches and a specificity score (higher is more specific)
207+
func matchModalias(deviceModAlias, patternModAlias *modAlias) (bool, int) {
208+
specificity := 0
209+
210+
// Compare each field - wildcards in pattern match anything
211+
// More specific matches (fewer wildcards) get higher specificity scores
212+
213+
// Match vendor
214+
if matches, score := matchField(deviceModAlias.vendor, patternModAlias.vendor); !matches {
215+
return false, 0
216+
} else {
217+
specificity += score
218+
}
219+
220+
// Match device
221+
if matches, score := matchField(deviceModAlias.device, patternModAlias.device); !matches {
222+
return false, 0
223+
} else {
224+
specificity += score
225+
}
226+
227+
// Match subvendor
228+
if matches, score := matchField(deviceModAlias.subvendor, patternModAlias.subvendor); !matches {
229+
return false, 0
230+
} else {
231+
specificity += score
232+
}
233+
234+
// Match subdevice
235+
if matches, score := matchField(deviceModAlias.subdevice, patternModAlias.subdevice); !matches {
236+
return false, 0
237+
} else {
238+
specificity += score
239+
}
240+
241+
// Match base class
242+
if matches, score := matchField(deviceModAlias.baseClass, patternModAlias.baseClass); !matches {
243+
return false, 0
244+
} else {
245+
specificity += score
246+
}
247+
248+
// Match subclass
249+
if matches, score := matchField(deviceModAlias.subClass, patternModAlias.subClass); !matches {
250+
return false, 0
251+
} else {
252+
specificity += score
253+
}
254+
255+
// Match interface
256+
if matches, score := matchField(deviceModAlias.interface_, patternModAlias.interface_); !matches {
257+
return false, 0
258+
} else {
259+
specificity += score
260+
}
261+
262+
return true, specificity
263+
}
264+
265+
// matchField compares one modalias field of a device against the same field
// of a pattern.
//
// A pattern value containing "*" anywhere matches every device value and
// scores 0. Otherwise the two values must be identical, in which case the
// score is the length of the pattern value. Any other combination yields
// (false, 0).
func matchField(deviceValue, patternValue string) (bool, int) {
	switch {
	case strings.Contains(patternValue, "*"):
		// Wildcard: always a match, but contributes no specificity.
		return true, 0
	case deviceValue == patternValue:
		// Exact match: longer fields count as more specific.
		return true, len(patternValue)
	default:
		return false, 0
	}
}

internal/nvpci/nvpci.go

Lines changed: 66 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -54,15 +54,18 @@ func (w *nvpciWrapper) UnbindFromDriver(dev *nvpci.NvidiaPCIDevice) error {
5454
}
5555

5656
func (d *nvidiaPCIDevice) bindToVFIODriver() error {
57-
// TODO: Instead of always binding to vfio-pci, check if a vfio variant module
58-
// should be used instead. This is required for GB200 where the nvgrace-gpu-vfio-pci
59-
// module must be used instead of vfio-pci.
60-
if d.Driver != vfioPCIDriverName {
57+
vfioDriverName, err := d.findBestVFIOVariant()
58+
if err != nil {
59+
return fmt.Errorf("failed to find best vfio variant: %w", err)
60+
}
61+
logrus.Infof("best vfio variant for %s: %s", d.Address, vfioDriverName)
62+
63+
if d.Driver != vfioDriverName {
6164
if err := unbind(d.Address); err != nil {
6265
return fmt.Errorf("failed to unbind device %s: %w", d.Address, err)
6366
}
64-
if err := bind(d.Address, vfioPCIDriverName); err != nil {
65-
return fmt.Errorf("failed to bind device %s to %s: %w", d.Address, vfioPCIDriverName, err)
67+
if err := bind(d.Address, vfioDriverName); err != nil {
68+
return fmt.Errorf("failed to bind device %s to %s: %w", d.Address, vfioDriverName, err)
6669
}
6770
}
6871

@@ -78,15 +81,15 @@ func (d *nvidiaPCIDevice) bindToVFIODriver() error {
7881
if err != nil {
7982
return fmt.Errorf("failed to get driver for graphics auxiliary device %s: %w", auxDev, err)
8083
}
81-
if auxDevDriver == vfioPCIDriverName {
84+
if auxDevDriver == vfioDriverName {
8285
return nil
8386
}
8487

8588
if err := unbind(auxDev); err != nil {
8689
return fmt.Errorf("failed to unbind graphics auxiliary device %s: %w", auxDev, err)
8790
}
88-
if err := bind(auxDev, vfioPCIDriverName); err != nil {
89-
return fmt.Errorf("failed to bind graphics auxiliary device %s to %s: %w", auxDev, vfioPCIDriverName, err)
91+
if err := bind(auxDev, vfioDriverName); err != nil {
92+
return fmt.Errorf("failed to bind graphics auxiliary device %s to %s: %w", auxDev, vfioDriverName, err)
9093
}
9194

9295
return nil
@@ -190,3 +193,57 @@ func unbind(device string) error {
190193

191194
return nil
192195
}
196+
197+
/* findBestVFIOVariant:
198+
*
199+
* Find the "best" match of all vfio_pci aliases for dev in the host
200+
* modules.alias file. This uses the algorithm of finding every
201+
* modules.alias line that begins with "vfio_pci:", then picking the
202+
* one that matches the device's own modalias value (from the file of
203+
* that name in the device's sysfs directory) with the fewest
204+
* "wildcards" (* character, meaning "match any value for this
205+
* attribute").
206+
*/
207+
func (d *nvidiaPCIDevice) findBestVFIOVariant() (string, error) {
208+
modAliasPath := filepath.Join(d.Path, "modalias")
209+
modAliasContent, err := os.ReadFile(modAliasPath)
210+
if err != nil {
211+
return "", fmt.Errorf("failed to read modalias file for %s: %w", d.Address, err)
212+
}
213+
214+
modAliasStr := strings.TrimSpace(string(modAliasContent))
215+
modAlias, err := parseModAlias(modAliasStr)
216+
if err != nil {
217+
return "", fmt.Errorf("failed to parse modalias for %s: %w", d.Address, err)
218+
}
219+
logrus.Infof("modalias for %s: %+v", d.Address, modAlias)
220+
221+
kernelVersion, err := getKernelVersion()
222+
if err != nil {
223+
return "", fmt.Errorf("failed to get kernel version: %w", err)
224+
}
225+
logrus.Infof("kernel version: %s", kernelVersion)
226+
227+
modulesAliasFilePath := filepath.Join("/lib/modules", kernelVersion, "modules.alias")
228+
modulesAliasContent, err := os.ReadFile(modulesAliasFilePath)
229+
if err != nil {
230+
return "", fmt.Errorf("failed to read modules.alias file: %w", err)
231+
}
232+
233+
// Parse modules.alias and find all vfio_pci entries
234+
vfioAliases := parseVFIOAliases(string(modulesAliasContent))
235+
if len(vfioAliases) == 0 {
236+
logrus.Warnf("No vfio_pci entries found in modules.alias, falling back to default vfio-pci")
237+
return vfioPCIDriverName, nil
238+
}
239+
240+
// Find the best matching VFIO driver for this device
241+
bestMatch := findBestMatch(modAlias, vfioAliases)
242+
if bestMatch == "" {
243+
logrus.Warnf("No matching vfio driver found for device %s, falling back to default vfio-pci", d.Address)
244+
return vfioPCIDriverName, nil
245+
}
246+
247+
logrus.Infof("Best VFIO driver for device %s: %s", d.Address, bestMatch)
248+
return bestMatch, nil
249+
}

0 commit comments

Comments
 (0)