Skip to content

Commit fcfa6cf

Browse files
committed
wip: update vfio-manage to choose best VFIO driver
Signed-off-by: Christopher Desiniotis <[email protected]>
1 parent e5733cc commit fcfa6cf

File tree

2 files changed

+330
-9
lines changed

2 files changed

+330
-9
lines changed

internal/nvpci/modalias.go

Lines changed: 264 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,264 @@
1+
package nvpci
2+
3+
import (
4+
"fmt"
5+
"strings"
6+
7+
"github.com/sirupsen/logrus"
8+
"golang.org/x/sys/unix"
9+
)
10+
11+
// modAlias holds the individual fields of a PCI modalias string such as
//
//	vNNNNNNNNdNNNNNNNNsvNNNNNNNNsdNNNNNNNNbcNNscNNiNN
//
// (followed by a space or newline). Every field is stored as the raw text
// found after its tag. The "NNNN" parts normally have the widths shown above
// unless a field is replaced by the wildcard "*", but nothing here relies on
// a particular width.
type modAlias struct {
	// vendor is the text that follows the "v" tag.
	vendor string
	// device is the text that follows the "d" tag.
	device string
	// subvendor is the text that follows the "sv" tag.
	subvendor string
	// subdevice is the text that follows the "sd" tag.
	subdevice string
	// baseClass is the text that follows the "bc" tag.
	baseClass string
	// subClass is the text that follows the "sc" tag.
	subClass string
	// interface_ is the text that follows the "i" tag.
	interface_ string
}
27+
28+
// vfioAlias represents an entry from the modules.alias file for a vfio driver
29+
type vfioAlias struct {
30+
modAlias *modAlias // The modalias pattern
31+
driver string // The vfio driver name
32+
}
33+
34+
// parseModAliasString parses a modalias string in the format:
35+
// vNNNNNNNNdNNNNNNNNsvNNNNNNNNsdNNNNNNNNbcNNscNNiNN
36+
// where N can be hex digits or wildcards (*)
37+
func parseModAliasString(input string) (*modAlias, error) {
38+
if input == "" {
39+
return nil, fmt.Errorf("modalias string is empty")
40+
}
41+
42+
// Trim any whitespace or newlines
43+
input = strings.TrimSpace(input)
44+
45+
// Trim the leading "pci:" prefix in the modalias file
46+
split := strings.SplitN(input, ":", 2)
47+
if len(split) != 2 {
48+
return nil, fmt.Errorf("unexpected number of parts in modalias after trimming pci: prefix: %s", input)
49+
}
50+
input = split[1]
51+
52+
// Parse vendor ID (starts with 'v')
53+
if !strings.HasPrefix(input, "v") {
54+
return nil, fmt.Errorf("modalias must start with 'v', got: %s", input)
55+
}
56+
57+
ma := &modAlias{}
58+
remaining := input[1:] // skip 'v'
59+
60+
// Extract vendor
61+
vendor, remaining, err := extractField(remaining, "d")
62+
if err != nil {
63+
return nil, fmt.Errorf("failed to parse vendor: %w", err)
64+
}
65+
ma.vendor = vendor
66+
67+
// Extract device
68+
device, remaining, err := extractField(remaining, "sv")
69+
if err != nil {
70+
return nil, fmt.Errorf("failed to parse device: %w", err)
71+
}
72+
ma.device = device
73+
74+
// Extract subvendor
75+
subvendor, remaining, err := extractField(remaining, "sd")
76+
if err != nil {
77+
return nil, fmt.Errorf("failed to parse subvendor: %w", err)
78+
}
79+
ma.subvendor = subvendor
80+
81+
// Extract subdevice
82+
subdevice, remaining, err := extractField(remaining, "bc")
83+
if err != nil {
84+
return nil, fmt.Errorf("failed to parse subdevice: %w", err)
85+
}
86+
ma.subdevice = subdevice
87+
88+
// Extract base class
89+
baseClass, remaining, err := extractField(remaining, "sc")
90+
if err != nil {
91+
return nil, fmt.Errorf("failed to parse base class: %w", err)
92+
}
93+
ma.baseClass = baseClass
94+
95+
// Extract subclass
96+
subClass, remaining, err := extractField(remaining, "i")
97+
if err != nil {
98+
return nil, fmt.Errorf("failed to parse subclass: %w", err)
99+
}
100+
ma.subClass = subClass
101+
102+
// Extract interface (remaining content)
103+
ma.interface_ = remaining
104+
105+
return ma, nil
106+
}
107+
108+
// extractField splits input at the first occurrence of delimiter.
//
// It returns the text before the delimiter, the text after it (the delimiter
// itself is dropped), and a nil error. When the delimiter does not occur in
// input the string is not a complete modalias, so an error is returned
// instead of silently treating the rest of the input as the field value and
// leaving every following field empty.
func extractField(input, delimiter string) (string, string, error) {
	idx := strings.Index(input, delimiter)
	if idx < 0 {
		return "", "", fmt.Errorf("delimiter %q not found in %q", delimiter, input)
	}

	return input[:idx], input[idx+len(delimiter):], nil
}
122+
123+
func getKernelVersion() (string, error) {
124+
var uname unix.Utsname
125+
if err := unix.Uname(&uname); err != nil {
126+
return "", err
127+
}
128+
129+
// Convert C-style byte array to Go string
130+
release := make([]byte, 0, len(uname.Release))
131+
for _, c := range uname.Release {
132+
if c == 0 {
133+
break
134+
}
135+
release = append(release, byte(c))
136+
}
137+
138+
return string(release), nil
139+
}
140+
141+
// parseVFIOAliases parses the modules.alias file and extracts all entries
142+
// that contain "vfio_pci" in the driver name
143+
func parseVFIOAliases(content string) []vfioAlias {
144+
var aliases []vfioAlias
145+
146+
lines := strings.Split(content, "\n")
147+
for _, line := range lines {
148+
line = strings.TrimSpace(line)
149+
150+
if !strings.HasPrefix(line, "alias vfio_pci:") {
151+
continue
152+
}
153+
154+
split := strings.SplitN(line, " ", 3)
155+
if len(split) != 3 {
156+
logrus.Warnf("malformed modules.alias 'vfio_pci' line: %q", line)
157+
continue
158+
}
159+
modAliasStr := split[1]
160+
modAlias, err := parseModAliasString(modAliasStr)
161+
if err != nil {
162+
logrus.Warnf("failed to parse modalias string %q: %v", modAliasStr, err)
163+
continue
164+
}
165+
logrus.Infof("modalias for %s: %+v", modAliasStr, modAlias)
166+
167+
driver := split[2]
168+
aliases = append(aliases, vfioAlias{
169+
modAlias: modAlias,
170+
driver: driver,
171+
})
172+
}
173+
174+
return aliases
175+
}
176+
177+
// findBestMatch finds the best matching VFIO driver for the given modalias
178+
// by comparing against all available vfio alias patterns
179+
func findBestMatch(deviceModAlias *modAlias, aliases []vfioAlias) string {
180+
var bestDriver string
181+
bestSpecificity := -1
182+
183+
for _, alias := range aliases {
184+
if matches, specificity := matchModalias(deviceModAlias, alias.modAlias); matches {
185+
if specificity > bestSpecificity {
186+
bestDriver = alias.driver
187+
bestSpecificity = specificity
188+
}
189+
}
190+
}
191+
192+
return bestDriver
193+
}
194+
195+
// matchModalias checks if a device modalias matches a pattern from modules.alias
196+
// Returns true if it matches and a specificity score (higher is more specific)
197+
func matchModalias(deviceModAlias, patternModAlias *modAlias) (bool, int) {
198+
specificity := 0
199+
200+
// Compare each field - wildcards in pattern match anything
201+
// More specific matches (fewer wildcards) get higher specificity scores
202+
203+
if matches, score := matchField(deviceModAlias.vendor, patternModAlias.vendor); !matches {
204+
return false, 0
205+
} else {
206+
specificity += score
207+
}
208+
209+
if matches, score := matchField(deviceModAlias.device, patternModAlias.device); !matches {
210+
return false, 0
211+
} else {
212+
specificity += score
213+
}
214+
215+
if matches, score := matchField(deviceModAlias.subvendor, patternModAlias.subvendor); !matches {
216+
return false, 0
217+
} else {
218+
specificity += score
219+
}
220+
221+
if matches, score := matchField(deviceModAlias.subdevice, patternModAlias.subdevice); !matches {
222+
return false, 0
223+
} else {
224+
specificity += score
225+
}
226+
227+
if matches, score := matchField(deviceModAlias.baseClass, patternModAlias.baseClass); !matches {
228+
return false, 0
229+
} else {
230+
specificity += score
231+
}
232+
233+
if matches, score := matchField(deviceModAlias.subClass, patternModAlias.subClass); !matches {
234+
return false, 0
235+
} else {
236+
specificity += score
237+
}
238+
239+
if matches, score := matchField(deviceModAlias.interface_, patternModAlias.interface_); !matches {
240+
return false, 0
241+
} else {
242+
specificity += score
243+
}
244+
245+
return true, specificity
246+
}
247+
248+
// matchField reports whether deviceValue is accepted by patternValue, a
// single modalias field taken from a modules.alias pattern.
//
// A pattern without "*" must equal deviceValue exactly. Each "*" in the
// pattern stands for any run of characters (including none), so "*" accepts
// every value while "03*" only accepts values that start with "03".
//
// The second result is a specificity score: the number of literal (non-"*")
// characters in the pattern. It is len(patternValue) for an exact match and 0
// for a bare "*". The score is 0 whenever the field does not match.
func matchField(deviceValue, patternValue string) (bool, int) {
	// No wildcard: plain string comparison.
	if !strings.Contains(patternValue, "*") {
		if deviceValue != patternValue {
			return false, 0
		}
		return true, len(patternValue)
	}

	// With at least one "*", Split always yields two or more segments. The
	// first one is anchored at the start of the value and the last one at
	// its end; the segments in between must appear in order in what is left.
	segments := strings.Split(patternValue, "*")
	head, tail := segments[0], segments[len(segments)-1]

	if !strings.HasPrefix(deviceValue, head) {
		return false, 0
	}
	rest := deviceValue[len(head):]

	if !strings.HasSuffix(rest, tail) {
		return false, 0
	}
	rest = rest[:len(rest)-len(tail)]

	literal := len(head) + len(tail)
	for _, segment := range segments[1 : len(segments)-1] {
		at := strings.Index(rest, segment)
		if at < 0 {
			return false, 0
		}
		// Taking the leftmost occurrence leaves the most room for the
		// segments that follow.
		rest = rest[at+len(segment):]
		literal += len(segment)
	}

	return true, literal
}

internal/nvpci/nvpci.go

Lines changed: 66 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -54,15 +54,18 @@ func (w *nvpciWrapper) UnbindFromDriver(dev *nvpci.NvidiaPCIDevice) error {
5454
}
5555

5656
func (d *nvidiaPCIDevice) bindToVFIODriver() error {
57-
// TODO: Instead of always binding to vfio-pci, check if a vfio variant module
58-
// should be used instead. This is required for GB200 where the nvgrace-gpu-vfio-pci
59-
// module must be used instead of vfio-pci.
60-
if d.Driver != vfioPCIDriverName {
57+
vfioDriverName, err := d.findBestVFIOVariant()
58+
if err != nil {
59+
return fmt.Errorf("failed to find best vfio variant driver: %w", err)
60+
}
61+
logrus.Infof("best vfio variant driver for %s: %s", d.Address, vfioDriverName)
62+
63+
if d.Driver != vfioDriverName {
6164
if err := unbind(d.Address); err != nil {
6265
return fmt.Errorf("failed to unbind device %s: %w", d.Address, err)
6366
}
64-
if err := bind(d.Address, vfioPCIDriverName); err != nil {
65-
return fmt.Errorf("failed to bind device %s to %s: %w", d.Address, vfioPCIDriverName, err)
67+
if err := bind(d.Address, vfioDriverName); err != nil {
68+
return fmt.Errorf("failed to bind device %s to %s: %w", d.Address, vfioDriverName, err)
6669
}
6770
}
6871

@@ -78,15 +81,15 @@ func (d *nvidiaPCIDevice) bindToVFIODriver() error {
7881
if err != nil {
7982
return fmt.Errorf("failed to get driver for graphics auxiliary device %s: %w", auxDev, err)
8083
}
81-
if auxDevDriver == vfioPCIDriverName {
84+
if auxDevDriver == vfioDriverName {
8285
return nil
8386
}
8487

8588
if err := unbind(auxDev); err != nil {
8689
return fmt.Errorf("failed to unbind graphics auxiliary device %s: %w", auxDev, err)
8790
}
88-
if err := bind(auxDev, vfioPCIDriverName); err != nil {
89-
return fmt.Errorf("failed to bind graphics auxiliary device %s to %s: %w", auxDev, vfioPCIDriverName, err)
91+
if err := bind(auxDev, vfioDriverName); err != nil {
92+
return fmt.Errorf("failed to bind graphics auxiliary device %s to %s: %w", auxDev, vfioDriverName, err)
9093
}
9194

9295
return nil
@@ -190,3 +193,57 @@ func unbind(device string) error {
190193

191194
return nil
192195
}
196+
197+
/* findBestVFIOVariant:
198+
*
199+
* Find the "best" match of all vfio_pci aliases for dev in the host
200+
* modules.alias file. This uses the algorithm of finding every
201+
* modules.alias line that begins with "vfio_pci:", then picking the
202+
* one that matches the device's own modalias value (from the file of
203+
* that name in the device's sysfs directory) with the fewest
204+
* "wildcards" (* character, meaning "match any value for this
205+
* attribute").
206+
*/
207+
func (d *nvidiaPCIDevice) findBestVFIOVariant() (string, error) {
208+
modAliasPath := filepath.Join(d.Path, "modalias")
209+
modAliasContent, err := os.ReadFile(modAliasPath)
210+
if err != nil {
211+
return "", fmt.Errorf("failed to read modalias file for %s: %w", d.Address, err)
212+
}
213+
214+
modAliasStr := strings.TrimSpace(string(modAliasContent))
215+
modAlias, err := parseModAliasString(modAliasStr)
216+
if err != nil {
217+
return "", fmt.Errorf("failed to parse modalias string %q for device %q: %w", modAliasStr, d.Address, err)
218+
}
219+
logrus.Debugf("modalias for device %q: %+v", d.Address, modAlias)
220+
221+
kernelVersion, err := getKernelVersion()
222+
if err != nil {
223+
return "", fmt.Errorf("failed to get kernel version: %w", err)
224+
}
225+
logrus.Debugf("kernel version: %s", kernelVersion)
226+
227+
modulesAliasFilePath := filepath.Join("/lib/modules", kernelVersion, "modules.alias")
228+
modulesAliasContent, err := os.ReadFile(modulesAliasFilePath)
229+
if err != nil {
230+
return "", fmt.Errorf("failed to read file %s: %w", modulesAliasFilePath, err)
231+
}
232+
233+
// Parse modules.alias and find all vfio_pci entries
234+
vfioAliases := parseVFIOAliases(string(modulesAliasContent))
235+
if len(vfioAliases) == 0 {
236+
logrus.Warnf("No vfio_pci entries found in modules.alias file, falling back to default vfio-pci driver")
237+
return vfioPCIDriverName, nil
238+
}
239+
240+
// Find the best matching VFIO driver for this device
241+
bestMatch := findBestMatch(modAlias, vfioAliases)
242+
if bestMatch == "" {
243+
logrus.Warnf("No matching vfio driver found for device %s, falling back to default vfio-pci", d.Address)
244+
return vfioPCIDriverName, nil
245+
}
246+
247+
logrus.Infof("Best VFIO driver for device %s: %s", d.Address, bestMatch)
248+
return bestMatch, nil
249+
}

0 commit comments

Comments
 (0)