|
| 1 | +package nvpci |
| 2 | + |
| 3 | +import ( |
| 4 | + "fmt" |
| 5 | + "strings" |
| 6 | + |
| 7 | + "github.com/sirupsen/logrus" |
| 8 | + "golang.org/x/sys/unix" |
| 9 | +) |
| 10 | + |
| 11 | +// modAlias is a decomposed version of string like this |
| 12 | +// |
| 13 | +// vNNNNNNNNdNNNNNNNNsvNNNNNNNNsdNNNNNNNNbcNNscNNiNN |
| 14 | +// |
| 15 | +// (followed by a space or newline). The "NNNN" are always of the |
| 16 | +// length in the example unless replaced with a wildcard ("*"), |
| 17 | +// but we make no assumptions about length. |
| 18 | +type modAlias struct { |
| 19 | + vendor string // v |
| 20 | + device string // d |
| 21 | + subvendor string // sv |
| 22 | + subdevice string // sd |
| 23 | + baseClass string // bc |
| 24 | + subClass string // sc |
| 25 | + interface_ string // i |
| 26 | +} |
| 27 | + |
| 28 | +// vfioAlias represents an entry from modules.alias for a vfio driver |
| 29 | +type vfioAlias struct { |
| 30 | + modAlias *modAlias // The modalias pattern |
| 31 | + driver string // The vfio driver name |
| 32 | +} |
| 33 | + |
| 34 | +// parseModAlias parses a modalias string in the format: |
| 35 | +// vNNNNNNNNdNNNNNNNNsvNNNNNNNNsdNNNNNNNNbcNNscNNiNN |
| 36 | +// where N can be hex digits or wildcards (*) |
| 37 | +func parseModAlias(input string) (*modAlias, error) { |
| 38 | + if input == "" { |
| 39 | + return nil, fmt.Errorf("modalias string is empty") |
| 40 | + } |
| 41 | + |
| 42 | + // Trim any whitespace or newlines |
| 43 | + input = strings.TrimSpace(input) |
| 44 | + |
| 45 | + // Trim the leading "pci:" prefix in the modalias file |
| 46 | + split := strings.SplitN(input, ":", 2) |
| 47 | + if len(split) != 2 { |
| 48 | + return nil, fmt.Errorf("unexpected number of parts in modalias after trimming pci: prefix: %s", input) |
| 49 | + } |
| 50 | + input = split[1] |
| 51 | + |
| 52 | + // Parse vendor ID (starts with 'v') |
| 53 | + if !strings.HasPrefix(input, "v") { |
| 54 | + return nil, fmt.Errorf("modalias must start with 'v', got: %s", input) |
| 55 | + } |
| 56 | + |
| 57 | + ma := &modAlias{} |
| 58 | + remaining := input[1:] // skip 'v' |
| 59 | + |
| 60 | + // Extract vendor |
| 61 | + vendor, remaining, err := extractField(remaining, "d") |
| 62 | + if err != nil { |
| 63 | + return nil, fmt.Errorf("failed to parse vendor: %w", err) |
| 64 | + } |
| 65 | + ma.vendor = vendor |
| 66 | + |
| 67 | + // Extract device |
| 68 | + device, remaining, err := extractField(remaining, "sv") |
| 69 | + if err != nil { |
| 70 | + return nil, fmt.Errorf("failed to parse device: %w", err) |
| 71 | + } |
| 72 | + ma.device = device |
| 73 | + |
| 74 | + // Extract subvendor |
| 75 | + subvendor, remaining, err := extractField(remaining, "sd") |
| 76 | + if err != nil { |
| 77 | + return nil, fmt.Errorf("failed to parse subvendor: %w", err) |
| 78 | + } |
| 79 | + ma.subvendor = subvendor |
| 80 | + |
| 81 | + // Extract subdevice |
| 82 | + subdevice, remaining, err := extractField(remaining, "bc") |
| 83 | + if err != nil { |
| 84 | + return nil, fmt.Errorf("failed to parse subdevice: %w", err) |
| 85 | + } |
| 86 | + ma.subdevice = subdevice |
| 87 | + |
| 88 | + // Extract base class |
| 89 | + baseClass, remaining, err := extractField(remaining, "sc") |
| 90 | + if err != nil { |
| 91 | + return nil, fmt.Errorf("failed to parse base class: %w", err) |
| 92 | + } |
| 93 | + ma.baseClass = baseClass |
| 94 | + |
| 95 | + // Extract subclass |
| 96 | + subClass, remaining, err := extractField(remaining, "i") |
| 97 | + if err != nil { |
| 98 | + return nil, fmt.Errorf("failed to parse subclass: %w", err) |
| 99 | + } |
| 100 | + ma.subClass = subClass |
| 101 | + |
| 102 | + // Extract interface (remaining content) |
| 103 | + ma.interface_ = remaining |
| 104 | + |
| 105 | + return ma, nil |
| 106 | +} |
| 107 | + |
| 108 | +// extractField extracts the value before the next delimiter from the input string. |
| 109 | +// Returns the extracted value, the remaining string (without the delimiter), and any error. |
| 110 | +func extractField(input, delimiter string) (string, string, error) { |
| 111 | + idx := strings.Index(input, delimiter) |
| 112 | + if idx == -1 { |
| 113 | + // Delimiter not found - this could be the last field |
| 114 | + return input, "", nil |
| 115 | + } |
| 116 | + |
| 117 | + value := input[:idx] |
| 118 | + remaining := input[idx+len(delimiter):] |
| 119 | + |
| 120 | + return value, remaining, nil |
| 121 | +} |
| 122 | + |
| 123 | +func getKernelVersion() (string, error) { |
| 124 | + var uname unix.Utsname |
| 125 | + if err := unix.Uname(&uname); err != nil { |
| 126 | + return "", err |
| 127 | + } |
| 128 | + |
| 129 | + // Convert C-style byte array to Go string |
| 130 | + release := make([]byte, 0, len(uname.Release)) |
| 131 | + for _, c := range uname.Release { |
| 132 | + if c == 0 { |
| 133 | + break |
| 134 | + } |
| 135 | + release = append(release, byte(c)) |
| 136 | + } |
| 137 | + |
| 138 | + return string(release), nil |
| 139 | +} |
| 140 | + |
| 141 | +// parseVFIOAliases parses the modules.alias file and extracts all entries |
| 142 | +// that contain "vfio_pci" in the driver name |
| 143 | +func parseVFIOAliases(content string) []vfioAlias { |
| 144 | + var aliases []vfioAlias |
| 145 | + |
| 146 | + lines := strings.Split(content, "\n") |
| 147 | + for _, line := range lines { |
| 148 | + line = strings.TrimSpace(line) |
| 149 | + |
| 150 | + // Skip empty lines and comments |
| 151 | + if line == "" || strings.HasPrefix(line, "#") { |
| 152 | + continue |
| 153 | + } |
| 154 | + |
| 155 | + // Expected format: "alias <pattern> <driver>" |
| 156 | + parts := strings.Fields(line) |
| 157 | + if len(parts) < 3 { |
| 158 | + continue |
| 159 | + } |
| 160 | + |
| 161 | + if parts[0] != "alias" { |
| 162 | + continue |
| 163 | + } |
| 164 | + |
| 165 | + pattern := parts[1] |
| 166 | + driver := parts[2] |
| 167 | + |
| 168 | + // Filter for vfio drivers |
| 169 | + if strings.HasPrefix(driver, "vfio_pci:") { |
| 170 | + modAlias, err := parseModAlias(pattern) |
| 171 | + if err != nil { |
| 172 | + logrus.Warnf("failed to parse modalias for %s: %v", pattern, err) |
| 173 | + continue |
| 174 | + } |
| 175 | + logrus.Infof("modalias for %s: %+v", pattern, modAlias) |
| 176 | + aliases = append(aliases, vfioAlias{ |
| 177 | + modAlias: modAlias, |
| 178 | + driver: driver, |
| 179 | + }) |
| 180 | + } |
| 181 | + } |
| 182 | + |
| 183 | + return aliases |
| 184 | +} |
| 185 | + |
| 186 | +// findBestMatch finds the best matching VFIO driver for the given modalias |
| 187 | +// by comparing against all available vfio alias patterns |
| 188 | +func findBestMatch(deviceModAlias *modAlias, aliases []vfioAlias) string { |
| 189 | + // Track the best match (most specific pattern) |
| 190 | + var bestDriver string |
| 191 | + bestSpecificity := -1 |
| 192 | + |
| 193 | + for _, alias := range aliases { |
| 194 | + if matches, specificity := matchModalias(deviceModAlias, alias.modAlias); matches { |
| 195 | + if specificity > bestSpecificity { |
| 196 | + bestDriver = alias.driver |
| 197 | + bestSpecificity = specificity |
| 198 | + } |
| 199 | + } |
| 200 | + } |
| 201 | + |
| 202 | + return bestDriver |
| 203 | +} |
| 204 | + |
| 205 | +// matchModalias checks if a device modalias matches a pattern from modules.alias |
| 206 | +// Returns true if it matches and a specificity score (higher is more specific) |
| 207 | +func matchModalias(deviceModAlias, patternModAlias *modAlias) (bool, int) { |
| 208 | + specificity := 0 |
| 209 | + |
| 210 | + // Compare each field - wildcards in pattern match anything |
| 211 | + // More specific matches (fewer wildcards) get higher specificity scores |
| 212 | + |
| 213 | + // Match vendor |
| 214 | + if matches, score := matchField(deviceModAlias.vendor, patternModAlias.vendor); !matches { |
| 215 | + return false, 0 |
| 216 | + } else { |
| 217 | + specificity += score |
| 218 | + } |
| 219 | + |
| 220 | + // Match device |
| 221 | + if matches, score := matchField(deviceModAlias.device, patternModAlias.device); !matches { |
| 222 | + return false, 0 |
| 223 | + } else { |
| 224 | + specificity += score |
| 225 | + } |
| 226 | + |
| 227 | + // Match subvendor |
| 228 | + if matches, score := matchField(deviceModAlias.subvendor, patternModAlias.subvendor); !matches { |
| 229 | + return false, 0 |
| 230 | + } else { |
| 231 | + specificity += score |
| 232 | + } |
| 233 | + |
| 234 | + // Match subdevice |
| 235 | + if matches, score := matchField(deviceModAlias.subdevice, patternModAlias.subdevice); !matches { |
| 236 | + return false, 0 |
| 237 | + } else { |
| 238 | + specificity += score |
| 239 | + } |
| 240 | + |
| 241 | + // Match base class |
| 242 | + if matches, score := matchField(deviceModAlias.baseClass, patternModAlias.baseClass); !matches { |
| 243 | + return false, 0 |
| 244 | + } else { |
| 245 | + specificity += score |
| 246 | + } |
| 247 | + |
| 248 | + // Match subclass |
| 249 | + if matches, score := matchField(deviceModAlias.subClass, patternModAlias.subClass); !matches { |
| 250 | + return false, 0 |
| 251 | + } else { |
| 252 | + specificity += score |
| 253 | + } |
| 254 | + |
| 255 | + // Match interface |
| 256 | + if matches, score := matchField(deviceModAlias.interface_, patternModAlias.interface_); !matches { |
| 257 | + return false, 0 |
| 258 | + } else { |
| 259 | + specificity += score |
| 260 | + } |
| 261 | + |
| 262 | + return true, specificity |
| 263 | +} |
| 264 | + |
| 265 | +// matchField matches a single field from the modalias |
| 266 | +// Wildcards (*) in the pattern match any value |
| 267 | +// Returns true if it matches and a score (higher for exact matches) |
| 268 | +func matchField(deviceValue, patternValue string) (bool, int) { |
| 269 | + // Wildcard in pattern matches anything |
| 270 | + if patternValue == "*" || strings.Contains(patternValue, "*") { |
| 271 | + return true, 0 |
| 272 | + } |
| 273 | + |
| 274 | + // Exact match |
| 275 | + if deviceValue == patternValue { |
| 276 | + return true, len(patternValue) // Score based on field length |
| 277 | + } |
| 278 | + |
| 279 | + // No match |
| 280 | + return false, 0 |
| 281 | +} |
0 commit comments