Skip to content

Commit 20cc763

Browse files
committed
refactor vfio-manage subcommands
Signed-off-by: Christopher Desiniotis <[email protected]>
1 parent 589b001 commit 20cc763

File tree

4 files changed

+285
-167
lines changed

4 files changed

+285
-167
lines changed

cmd/vfio-manage/bind.go

Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
//go:build !darwin && !windows
2+
3+
/*
4+
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
5+
*
6+
* Licensed under the Apache License, Version 2.0 (the "License");
7+
* you may not use this file except in compliance with the License.
8+
* You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing, software
13+
* distributed under the License is distributed on an "AS IS" BASIS,
14+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
* See the License for the specific language governing permissions and
16+
* limitations under the License.
17+
*/
18+
19+
package main
20+
21+
import (
22+
"fmt"
23+
24+
"github.com/sirupsen/logrus"
25+
"github.com/urfave/cli/v2"
26+
27+
"github.com/NVIDIA/k8s-driver-manager/internal/nvpci"
28+
)
29+
30+
type bindCommand struct {
31+
logger *logrus.Logger
32+
nvpciLib nvpci.Interface
33+
}
34+
35+
// bindOptions holds the values parsed from the "bind" subcommand's flags.
type bindOptions struct {
	// all is set by --all / -a.
	all bool
	// deviceID is set by --device-id / -d (e.g. 0000:01:00.0).
	deviceID string
}
39+
40+
// newBindCommand constructs a bind command with the specified logger
41+
func newBindCommand(logger *logrus.Logger) *cli.Command {
42+
c := bindCommand{
43+
logger: logger,
44+
nvpciLib: nvpci.New(),
45+
}
46+
return c.build()
47+
}
48+
49+
// build the bind command
50+
func (m bindCommand) build() *cli.Command {
51+
cfg := bindOptions{}
52+
53+
// Create the 'bind' command
54+
c := cli.Command{
55+
Name: "bind",
56+
Usage: "Bind device(s) to vfio-pci driver",
57+
Before: func(c *cli.Context) error {
58+
return m.validateFlags(&cfg)
59+
},
60+
Action: func(c *cli.Context) error {
61+
return m.run(&cfg)
62+
},
63+
Flags: []cli.Flag{
64+
&cli.BoolFlag{
65+
Name: "all",
66+
Aliases: []string{"a"},
67+
Destination: &cfg.all,
68+
Usage: "Bind all NVIDIA devices to vfio-pci",
69+
},
70+
&cli.StringFlag{
71+
Name: "device-id",
72+
Aliases: []string{"d"},
73+
Destination: &cfg.deviceID,
74+
Usage: "Specific device ID to bind (e.g., 0000:01:00.0)",
75+
},
76+
},
77+
}
78+
79+
return &c
80+
}
81+
82+
func (m bindCommand) validateFlags(cfg *bindOptions) error {
83+
if !cfg.all && cfg.deviceID == "" {
84+
return fmt.Errorf("either --all or --device-id must be specified")
85+
}
86+
87+
if cfg.all && cfg.deviceID != "" {
88+
return fmt.Errorf("cannot specify both --all and --device-id")
89+
}
90+
91+
return nil
92+
}
93+
94+
func (m bindCommand) run(cfg *bindOptions) error {
95+
if cfg.deviceID != "" {
96+
return m.bindDevice(cfg.deviceID)
97+
}
98+
99+
return m.bindAll()
100+
}
101+
102+
func (m bindCommand) bindAll() error {
103+
devices, err := m.nvpciLib.GetGPUs()
104+
if err != nil {
105+
return fmt.Errorf("failed to get NVIDIA GPUs: %w", err)
106+
}
107+
108+
for _, dev := range devices {
109+
m.logger.Infof("Binding device %s", dev.Address)
110+
// (cdesiniotis) ideally this should be replaced by a call to nvdev.BindToVFIODriver()
111+
if err := m.nvpciLib.BindToVFIODriver(dev); err != nil {
112+
m.logger.Warnf("Failed to bind device %s: %v", dev.Address, err)
113+
}
114+
}
115+
116+
return nil
117+
}
118+
119+
func (m bindCommand) bindDevice(device string) error {
120+
nvdev, err := m.nvpciLib.GetGPUByPciBusID(device)
121+
if err != nil {
122+
return fmt.Errorf("failed to get NVIDIA GPU device: %w", err)
123+
}
124+
if nvdev == nil || !nvdev.IsGPU() {
125+
m.logger.Infof("Device %s is not a GPU", device)
126+
return nil
127+
}
128+
129+
m.logger.Infof("Binding device %s", device)
130+
131+
// (cdesiniotis) ideally this should be replaced by a call to nvdev.BindToVFIODriver()
132+
if err := m.nvpciLib.BindToVFIODriver(nvdev); err != nil {
133+
return fmt.Errorf("failed to bind device %s to vfio driver: %w", device, err)
134+
}
135+
136+
return nil
137+
}

cmd/vfio-manage/main.go

Lines changed: 5 additions & 166 deletions
Original file line numberDiff line numberDiff line change
@@ -19,26 +19,17 @@
1919
package main
2020

2121
import (
22-
"fmt"
2322
"os"
2423

2524
"github.com/sirupsen/logrus"
2625
"github.com/urfave/cli/v2"
2726

2827
"github.com/NVIDIA/k8s-driver-manager/internal/info"
29-
"github.com/NVIDIA/k8s-driver-manager/internal/nvpci"
3028
)
3129

32-
type flags struct {
33-
allDevices bool
34-
deviceID string
35-
}
36-
3730
func main() {
38-
flags := flags{}
39-
40-
log := logrus.New()
41-
log.SetFormatter(&logrus.TextFormatter{
31+
logger := logrus.New()
32+
logger.SetFormatter(&logrus.TextFormatter{
4233
FullTimestamp: true,
4334
DisableQuote: true,
4435
})
@@ -49,162 +40,10 @@ func main() {
4940
app.Version = info.GetVersionString()
5041

5142
app.Commands = []*cli.Command{
52-
{
53-
Name: "bind",
54-
Usage: "Bind device(s) to vfio-pci driver",
55-
Flags: []cli.Flag{
56-
&cli.BoolFlag{
57-
Name: "all",
58-
Aliases: []string{"a"},
59-
Destination: &flags.allDevices,
60-
Usage: "Bind all NVIDIA devices to vfio-pci",
61-
},
62-
&cli.StringFlag{
63-
Name: "device-id",
64-
Aliases: []string{"d"},
65-
Destination: &flags.deviceID,
66-
Usage: "Specific device ID to bind (e.g., 0000:01:00.0)",
67-
},
68-
},
69-
Before: func(c *cli.Context) error {
70-
return validateFlags(&flags)
71-
},
72-
Action: func(c *cli.Context) error {
73-
return handleBind(log, &flags)
74-
},
75-
},
76-
{
77-
Name: "unbind",
78-
Usage: "Unbind device(s) from their current driver",
79-
Flags: []cli.Flag{
80-
&cli.BoolFlag{
81-
Name: "all",
82-
Aliases: []string{"a"},
83-
Destination: &flags.allDevices,
84-
Usage: "Unbind all NVIDIA devices",
85-
},
86-
&cli.StringFlag{
87-
Name: "device-id",
88-
Aliases: []string{"d"},
89-
Destination: &flags.deviceID,
90-
Usage: "Specific device ID to unbind (e.g., 0000:01:00.0)",
91-
},
92-
},
93-
Before: func(c *cli.Context) error {
94-
return validateFlags(&flags)
95-
},
96-
Action: func(c *cli.Context) error {
97-
return handleUnbind(log, &flags)
98-
},
99-
},
43+
newBindCommand(logger),
44+
newUnbindCommand(logger),
10045
}
101-
10246
if err := app.Run(os.Args); err != nil {
103-
log.Fatal(err)
104-
}
105-
}
106-
107-
func validateFlags(flags *flags) error {
108-
if !flags.allDevices && flags.deviceID == "" {
109-
return fmt.Errorf("either --all or --device-id must be specified")
110-
}
111-
112-
if flags.allDevices && flags.deviceID != "" {
113-
return fmt.Errorf("cannot specify both --all and --device-id")
114-
}
115-
116-
return nil
117-
}
118-
119-
func handleBind(log *logrus.Logger, flags *flags) error {
120-
if flags.deviceID != "" {
121-
return bindDevice(flags.deviceID, log)
122-
}
123-
124-
return bindAll(log)
125-
}
126-
127-
func handleUnbind(log *logrus.Logger, flags *flags) error {
128-
if flags.deviceID != "" {
129-
return unbindDevice(flags.deviceID, log)
130-
}
131-
132-
return unbindAll(log)
133-
}
134-
135-
func bindAll(log *logrus.Logger) error {
136-
nvpciLib := nvpci.New()
137-
devices, err := nvpciLib.GetGPUs()
138-
if err != nil {
139-
return fmt.Errorf("failed to get NVIDIA GPUs: %w", err)
140-
}
141-
142-
for _, dev := range devices {
143-
log.Infof("Binding device %s", dev.Address)
144-
// (cdesiniotis) ideally this should be replaced by a call to nvdev.BindToVFIODriver()
145-
if err := nvpciLib.BindToVFIODriver(dev); err != nil {
146-
log.Warnf("Failed to bind device %s: %v", dev.Address, err)
147-
}
47+
logger.Fatal(err)
14848
}
149-
150-
return nil
151-
}
152-
153-
func unbindAll(log *logrus.Logger) error {
154-
nvpciLib := nvpci.New()
155-
devices, err := nvpciLib.GetGPUs()
156-
if err != nil {
157-
return fmt.Errorf("failed to get NVIDIA GPUs: %w", err)
158-
}
159-
160-
for _, dev := range devices {
161-
log.Infof("Unbinding device %s", dev.Address)
162-
// (cdesiniotis) ideally this should be replaced by a call to nvdev.UnbindFromDriver()
163-
if err := nvpciLib.UnbindFromDriver(dev); err != nil {
164-
log.Warnf("Failed to unbind device %s: %v", dev.Address, err)
165-
}
166-
}
167-
return nil
168-
}
169-
170-
func bindDevice(device string, log *logrus.Logger) error {
171-
nvpciLib := nvpci.New()
172-
nvdev, err := nvpciLib.GetGPUByPciBusID(device)
173-
if err != nil {
174-
return fmt.Errorf("failed to get NVIDIA GPU device: %w", err)
175-
}
176-
if nvdev == nil || !nvdev.IsGPU() {
177-
log.Infof("Device %s is not a GPU", device)
178-
return nil
179-
}
180-
181-
log.Infof("Binding device %s", device)
182-
183-
// (cdesiniotis) ideally this should be replaced by a call to nvdev.BindToVFIODriver()
184-
if err := nvpciLib.BindToVFIODriver(nvdev); err != nil {
185-
return fmt.Errorf("failed to bind device %s to vfio driver: %w", device, err)
186-
}
187-
188-
return nil
189-
}
190-
191-
func unbindDevice(device string, log *logrus.Logger) error {
192-
nvpciLib := nvpci.New()
193-
nvdev, err := nvpciLib.GetGPUByPciBusID(device)
194-
if err != nil {
195-
return fmt.Errorf("failed to get NVIDIA GPU device: %w", err)
196-
}
197-
if nvdev == nil || !nvdev.IsGPU() {
198-
log.Infof("Device %s is not a GPU", device)
199-
return nil
200-
}
201-
202-
log.Infof("Unbinding device %s", device)
203-
204-
// (cdesiniotis) ideally this should be replaced by a call to nvdev.UnbindFromDriver()
205-
if err := nvpciLib.UnbindFromDriver(nvdev); err != nil {
206-
return fmt.Errorf("failed to unbind device %s from driver: %w", device, err)
207-
}
208-
209-
return nil
21049
}

0 commit comments

Comments
 (0)