Skip to content

Commit b5da7d9

Browse files
committed
Update vfio-manage to choose best VFIO driver
Signed-off-by: Christopher Desiniotis <[email protected]>
1 parent ab47270 commit b5da7d9

File tree

2 files changed

+347
-9
lines changed

2 files changed

+347
-9
lines changed

internal/nvpci/modalias.go

Lines changed: 280 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,280 @@
1+
/*
2+
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package nvpci
18+
19+
import (
20+
"fmt"
21+
"strings"
22+
23+
"github.com/sirupsen/logrus"
24+
"golang.org/x/sys/unix"
25+
)
26+
27+
// modAlias is a decomposed version of string like this
28+
//
29+
// vNNNNNNNNdNNNNNNNNsvNNNNNNNNsdNNNNNNNNbcNNscNNiNN
30+
//
31+
// (followed by a space or newline). The "NNNN" are always of the
32+
// length in the example unless replaced with a wildcard ("*"),
33+
// but we make no assumptions about length.
34+
type modAlias struct {
35+
vendor string // v
36+
device string // d
37+
subvendor string // sv
38+
subdevice string // sd
39+
baseClass string // bc
40+
subClass string // sc
41+
interface_ string // i
42+
}
43+
44+
// vfioAlias represents an entry from the modules.alias file for a vfio driver
45+
type vfioAlias struct {
46+
modAlias *modAlias // The modalias pattern
47+
driver string // The vfio driver name
48+
}
49+
50+
// parseModAliasString parses a modalias string in the format:
51+
// vNNNNNNNNdNNNNNNNNsvNNNNNNNNsdNNNNNNNNbcNNscNNiNN
52+
// where N can be hex digits or wildcards (*)
53+
func parseModAliasString(input string) (*modAlias, error) {
54+
if input == "" {
55+
return nil, fmt.Errorf("modalias string is empty")
56+
}
57+
58+
// Trim any whitespace or newlines
59+
input = strings.TrimSpace(input)
60+
61+
// Trim the leading "pci:" prefix in the modalias file
62+
split := strings.SplitN(input, ":", 2)
63+
if len(split) != 2 {
64+
return nil, fmt.Errorf("unexpected number of parts in modalias after trimming pci: prefix: %s", input)
65+
}
66+
input = split[1]
67+
68+
// Parse vendor ID (starts with 'v')
69+
if !strings.HasPrefix(input, "v") {
70+
return nil, fmt.Errorf("modalias must start with 'v', got: %s", input)
71+
}
72+
73+
ma := &modAlias{}
74+
remaining := input[1:] // skip 'v'
75+
76+
// Extract vendor
77+
vendor, remaining, err := extractField(remaining, "d")
78+
if err != nil {
79+
return nil, fmt.Errorf("failed to parse vendor: %w", err)
80+
}
81+
ma.vendor = vendor
82+
83+
// Extract device
84+
device, remaining, err := extractField(remaining, "sv")
85+
if err != nil {
86+
return nil, fmt.Errorf("failed to parse device: %w", err)
87+
}
88+
ma.device = device
89+
90+
// Extract subvendor
91+
subvendor, remaining, err := extractField(remaining, "sd")
92+
if err != nil {
93+
return nil, fmt.Errorf("failed to parse subvendor: %w", err)
94+
}
95+
ma.subvendor = subvendor
96+
97+
// Extract subdevice
98+
subdevice, remaining, err := extractField(remaining, "bc")
99+
if err != nil {
100+
return nil, fmt.Errorf("failed to parse subdevice: %w", err)
101+
}
102+
ma.subdevice = subdevice
103+
104+
// Extract base class
105+
baseClass, remaining, err := extractField(remaining, "sc")
106+
if err != nil {
107+
return nil, fmt.Errorf("failed to parse base class: %w", err)
108+
}
109+
ma.baseClass = baseClass
110+
111+
// Extract subclass
112+
subClass, remaining, err := extractField(remaining, "i")
113+
if err != nil {
114+
return nil, fmt.Errorf("failed to parse subclass: %w", err)
115+
}
116+
ma.subClass = subClass
117+
118+
// Extract interface (remaining content)
119+
ma.interface_ = remaining
120+
121+
return ma, nil
122+
}
123+
124+
// extractField extracts the value before the next delimiter from the input string.
125+
// Returns the extracted value, the remaining string (without the delimiter), and any error.
126+
func extractField(input, delimiter string) (string, string, error) {
127+
idx := strings.Index(input, delimiter)
128+
if idx == -1 {
129+
// Delimiter not found - this could be the last field
130+
return input, "", nil
131+
}
132+
133+
value := input[:idx]
134+
remaining := input[idx+len(delimiter):]
135+
136+
return value, remaining, nil
137+
}
138+
139+
func getKernelVersion() (string, error) {
140+
var uname unix.Utsname
141+
if err := unix.Uname(&uname); err != nil {
142+
return "", err
143+
}
144+
145+
// Convert C-style byte array to Go string
146+
release := make([]byte, 0, len(uname.Release))
147+
for _, c := range uname.Release {
148+
if c == 0 {
149+
break
150+
}
151+
release = append(release, byte(c))
152+
}
153+
154+
return string(release), nil
155+
}
156+
157+
// parseVFIOAliases parses the modules.alias file and extracts all entries
158+
// that contain "vfio_pci" in the driver name
159+
func parseVFIOAliases(content string) []vfioAlias {
160+
var aliases []vfioAlias
161+
162+
lines := strings.Split(content, "\n")
163+
for _, line := range lines {
164+
line = strings.TrimSpace(line)
165+
166+
if !strings.HasPrefix(line, "alias vfio_pci:") {
167+
continue
168+
}
169+
170+
split := strings.SplitN(line, " ", 3)
171+
if len(split) != 3 {
172+
logrus.Warnf("malformed modules.alias 'vfio_pci' line: %q", line)
173+
continue
174+
}
175+
modAliasStr := split[1]
176+
modAlias, err := parseModAliasString(modAliasStr)
177+
if err != nil {
178+
logrus.Warnf("failed to parse modalias string %q: %v", modAliasStr, err)
179+
continue
180+
}
181+
logrus.Infof("modalias for %s: %+v", modAliasStr, modAlias)
182+
183+
driver := split[2]
184+
aliases = append(aliases, vfioAlias{
185+
modAlias: modAlias,
186+
driver: driver,
187+
})
188+
}
189+
190+
return aliases
191+
}
192+
193+
// findBestMatch finds the best matching VFIO driver for the given modalias
194+
// by comparing against all available vfio alias patterns
195+
func findBestMatch(deviceModAlias *modAlias, aliases []vfioAlias) string {
196+
var bestDriver string
197+
bestSpecificity := -1
198+
199+
for _, alias := range aliases {
200+
if matches, specificity := matchModalias(deviceModAlias, alias.modAlias); matches {
201+
if specificity > bestSpecificity {
202+
bestDriver = alias.driver
203+
bestSpecificity = specificity
204+
}
205+
}
206+
}
207+
208+
return bestDriver
209+
}
210+
211+
// matchModalias checks if a device modalias matches a pattern from modules.alias
212+
// Returns true if it matches and a specificity score (higher is more specific)
213+
func matchModalias(deviceModAlias, patternModAlias *modAlias) (bool, int) {
214+
specificity := 0
215+
216+
// Compare each field - wildcards in pattern match anything
217+
// More specific matches (fewer wildcards) get higher specificity scores
218+
219+
if matches, score := matchField(deviceModAlias.vendor, patternModAlias.vendor); !matches {
220+
return false, 0
221+
} else {
222+
specificity += score
223+
}
224+
225+
if matches, score := matchField(deviceModAlias.device, patternModAlias.device); !matches {
226+
return false, 0
227+
} else {
228+
specificity += score
229+
}
230+
231+
if matches, score := matchField(deviceModAlias.subvendor, patternModAlias.subvendor); !matches {
232+
return false, 0
233+
} else {
234+
specificity += score
235+
}
236+
237+
if matches, score := matchField(deviceModAlias.subdevice, patternModAlias.subdevice); !matches {
238+
return false, 0
239+
} else {
240+
specificity += score
241+
}
242+
243+
if matches, score := matchField(deviceModAlias.baseClass, patternModAlias.baseClass); !matches {
244+
return false, 0
245+
} else {
246+
specificity += score
247+
}
248+
249+
if matches, score := matchField(deviceModAlias.subClass, patternModAlias.subClass); !matches {
250+
return false, 0
251+
} else {
252+
specificity += score
253+
}
254+
255+
if matches, score := matchField(deviceModAlias.interface_, patternModAlias.interface_); !matches {
256+
return false, 0
257+
} else {
258+
specificity += score
259+
}
260+
261+
return true, specificity
262+
}
263+
264+
// matchField matches a single field from the modalias
265+
// Wildcards (*) in the pattern match any value
266+
// Returns true if it matches and a score (higher for exact matches)
267+
func matchField(deviceValue, patternValue string) (bool, int) {
268+
// Wildcard in pattern matches anything
269+
if patternValue == "*" || strings.Contains(patternValue, "*") {
270+
return true, 0
271+
}
272+
273+
// Exact match
274+
if deviceValue == patternValue {
275+
return true, len(patternValue) // Score based on field length
276+
}
277+
278+
// No match
279+
return false, 0
280+
}

internal/nvpci/nvpci.go

Lines changed: 67 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import (
2323
"strings"
2424

2525
"github.com/NVIDIA/go-nvlib/pkg/nvpci"
26+
"github.com/sirupsen/logrus"
2627
)
2728

2829
const (
@@ -74,15 +75,18 @@ func (w *nvpciWrapper) UnbindFromDriver(dev *nvpci.NvidiaPCIDevice) error {
7475
}
7576

7677
func (d *nvidiaPCIDevice) bindToVFIODriver() error {
77-
// TODO: Instead of always binding to vfio-pci, check if a vfio variant module
78-
// should be used instead. This is required for GB200 where the nvgrace-gpu-vfio-pci
79-
// module must be used instead of vfio-pci.
80-
if d.Driver != vfioPCIDriverName {
78+
vfioDriverName, err := d.findBestVFIOVariant()
79+
if err != nil {
80+
return fmt.Errorf("failed to find best vfio variant driver: %w", err)
81+
}
82+
logrus.Infof("best vfio variant driver for %s: %s", d.Address, vfioDriverName)
83+
84+
if d.Driver != vfioDriverName {
8185
if err := unbind(d.Address); err != nil {
8286
return fmt.Errorf("failed to unbind device %s: %w", d.Address, err)
8387
}
84-
if err := bind(d.Address, vfioPCIDriverName); err != nil {
85-
return fmt.Errorf("failed to bind device %s to %s: %w", d.Address, vfioPCIDriverName, err)
88+
if err := bind(d.Address, vfioDriverName); err != nil {
89+
return fmt.Errorf("failed to bind device %s to %s: %w", d.Address, vfioDriverName, err)
8690
}
8791
}
8892

@@ -94,15 +98,15 @@ func (d *nvidiaPCIDevice) bindToVFIODriver() error {
9498
if auxDev == nil {
9599
return nil
96100
}
97-
if auxDev.Driver == vfioPCIDriverName {
101+
if auxDev.Driver == vfioDriverName {
98102
return nil
99103
}
100104

101105
if err := unbind(auxDev.Address); err != nil {
102106
return fmt.Errorf("failed to unbind graphics auxiliary device %s: %w", auxDev.Address, err)
103107
}
104-
if err := bind(auxDev.Address, vfioPCIDriverName); err != nil {
105-
return fmt.Errorf("failed to bind graphics auxiliary device %s to %s: %w", auxDev.Address, vfioPCIDriverName, err)
108+
if err := bind(auxDev.Address, vfioDriverName); err != nil {
109+
return fmt.Errorf("failed to bind graphics auxiliary device %s to %s: %w", auxDev, vfioDriverName, err)
106110
}
107111

108112
return nil
@@ -218,3 +222,57 @@ func unbind(device string) error {
218222

219223
return nil
220224
}
225+
226+
/* findBestVFIOVariant:
227+
*
228+
* Find the "best" match of all vfio_pci aliases for dev in the host
229+
* modules.alias file. This uses the algorithm of finding every
230+
* modules.alias line that begins with "vfio_pci:", then picking the
231+
* one that matches the device's own modalias value (from the file of
232+
* that name in the device's sysfs directory) with the fewest
233+
* "wildcards" (* character, meaning "match any value for this
234+
* attribute").
235+
*/
236+
func (d *nvidiaPCIDevice) findBestVFIOVariant() (string, error) {
237+
modAliasPath := filepath.Join(d.Path, "modalias")
238+
modAliasContent, err := os.ReadFile(modAliasPath)
239+
if err != nil {
240+
return "", fmt.Errorf("failed to read modalias file for %s: %w", d.Address, err)
241+
}
242+
243+
modAliasStr := strings.TrimSpace(string(modAliasContent))
244+
modAlias, err := parseModAliasString(modAliasStr)
245+
if err != nil {
246+
return "", fmt.Errorf("failed to parse modalias string %q for device %q: %w", modAliasStr, d.Address, err)
247+
}
248+
logrus.Debugf("modalias for device %q: %+v", d.Address, modAlias)
249+
250+
kernelVersion, err := getKernelVersion()
251+
if err != nil {
252+
return "", fmt.Errorf("failed to get kernel version: %w", err)
253+
}
254+
logrus.Debugf("kernel version: %s", kernelVersion)
255+
256+
modulesAliasFilePath := filepath.Join("/lib/modules", kernelVersion, "modules.alias")
257+
modulesAliasContent, err := os.ReadFile(modulesAliasFilePath)
258+
if err != nil {
259+
return "", fmt.Errorf("failed to read file %s: %w", modulesAliasFilePath, err)
260+
}
261+
262+
// Parse modules.alias and find all vfio_pci entries
263+
vfioAliases := parseVFIOAliases(string(modulesAliasContent))
264+
if len(vfioAliases) == 0 {
265+
logrus.Warnf("No vfio_pci entries found in modules.alias file, falling back to default vfio-pci driver")
266+
return vfioPCIDriverName, nil
267+
}
268+
269+
// Find the best matching VFIO driver for this device
270+
bestMatch := findBestMatch(modAlias, vfioAliases)
271+
if bestMatch == "" {
272+
logrus.Warnf("No matching vfio driver found for device %s in modules.alias file, falling back to default vfio-pci driver", d.Address)
273+
return vfioPCIDriverName, nil
274+
}
275+
276+
logrus.Infof("Best VFIO driver for device %s: %s", d.Address, bestMatch)
277+
return bestMatch, nil
278+
}

0 commit comments

Comments
 (0)