Skip to content

Commit ab47270

Browse files
authored
Merge pull request #127 from NVIDIA/rewrite-vfio-manager-in-go
Rewrite vfio-manager script in go
2 parents c71412b + 7002214 commit ab47270

File tree

24 files changed

+43724
-3
lines changed

24 files changed

+43724
-3
lines changed

cmd/vfio-manage/bind.go

Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
//go:build !darwin && !windows
2+
3+
/*
4+
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
5+
*
6+
* Licensed under the Apache License, Version 2.0 (the "License");
7+
* you may not use this file except in compliance with the License.
8+
* You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing, software
13+
* distributed under the License is distributed on an "AS IS" BASIS,
14+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
* See the License for the specific language governing permissions and
16+
* limitations under the License.
17+
*/
18+
19+
package main
20+
21+
import (
22+
"fmt"
23+
24+
"github.com/sirupsen/logrus"
25+
"github.com/urfave/cli/v2"
26+
27+
"github.com/NVIDIA/k8s-driver-manager/internal/nvpci"
28+
)
29+
30+
type bindCommand struct {
31+
logger *logrus.Logger
32+
nvpciLib nvpci.Interface
33+
}
34+
35+
type bindOptions struct {
36+
all bool
37+
deviceID string
38+
}
39+
40+
// newBindCommand constructs a bind command with the specified logger
41+
func newBindCommand(logger *logrus.Logger) *cli.Command {
42+
c := bindCommand{
43+
logger: logger,
44+
nvpciLib: nvpci.New(),
45+
}
46+
return c.build()
47+
}
48+
49+
// build the bind command
50+
func (m bindCommand) build() *cli.Command {
51+
cfg := bindOptions{}
52+
53+
// Create the 'bind' command
54+
c := cli.Command{
55+
Name: "bind",
56+
Usage: "Bind device(s) to vfio-pci driver",
57+
Before: func(c *cli.Context) error {
58+
return m.validateFlags(&cfg)
59+
},
60+
Action: func(c *cli.Context) error {
61+
return m.run(&cfg)
62+
},
63+
Flags: []cli.Flag{
64+
&cli.BoolFlag{
65+
Name: "all",
66+
Aliases: []string{"a"},
67+
Destination: &cfg.all,
68+
Usage: "Bind all NVIDIA devices to vfio-pci",
69+
},
70+
&cli.StringFlag{
71+
Name: "device-id",
72+
Aliases: []string{"d"},
73+
Destination: &cfg.deviceID,
74+
Usage: "Specific device ID to bind (e.g., 0000:01:00.0)",
75+
},
76+
},
77+
}
78+
79+
return &c
80+
}
81+
82+
func (m bindCommand) validateFlags(cfg *bindOptions) error {
83+
if !cfg.all && cfg.deviceID == "" {
84+
return fmt.Errorf("either --all or --device-id must be specified")
85+
}
86+
87+
if cfg.all && cfg.deviceID != "" {
88+
return fmt.Errorf("cannot specify both --all and --device-id")
89+
}
90+
91+
return nil
92+
}
93+
94+
func (m bindCommand) run(cfg *bindOptions) error {
95+
if cfg.deviceID != "" {
96+
return m.bindDevice(cfg.deviceID)
97+
}
98+
99+
return m.bindAll()
100+
}
101+
102+
func (m bindCommand) bindAll() error {
103+
devices, err := m.nvpciLib.GetGPUs()
104+
if err != nil {
105+
return fmt.Errorf("failed to get NVIDIA GPUs: %w", err)
106+
}
107+
108+
for _, dev := range devices {
109+
m.logger.Infof("Binding device %s", dev.Address)
110+
// (cdesiniotis) ideally this should be replaced by a call to nvdev.BindToVFIODriver()
111+
if err := m.nvpciLib.BindToVFIODriver(dev); err != nil {
112+
m.logger.Warnf("Failed to bind device %s: %v", dev.Address, err)
113+
}
114+
}
115+
116+
return nil
117+
}
118+
119+
func (m bindCommand) bindDevice(device string) error {
120+
nvdev, err := m.nvpciLib.GetGPUByPciBusID(device)
121+
if err != nil {
122+
return fmt.Errorf("failed to get NVIDIA GPU device: %w", err)
123+
}
124+
if nvdev == nil || !nvdev.IsGPU() {
125+
m.logger.Infof("Device %s is not a GPU", device)
126+
return nil
127+
}
128+
129+
m.logger.Infof("Binding device %s", device)
130+
131+
// (cdesiniotis) ideally this should be replaced by a call to nvdev.BindToVFIODriver()
132+
if err := m.nvpciLib.BindToVFIODriver(nvdev); err != nil {
133+
return fmt.Errorf("failed to bind device %s to vfio driver: %w", device, err)
134+
}
135+
136+
return nil
137+
}

cmd/vfio-manage/main.go

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
//go:build !darwin && !windows
2+
3+
/*
4+
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
5+
*
6+
* Licensed under the Apache License, Version 2.0 (the "License");
7+
* you may not use this file except in compliance with the License.
8+
* You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing, software
13+
* distributed under the License is distributed on an "AS IS" BASIS,
14+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
* See the License for the specific language governing permissions and
16+
* limitations under the License.
17+
*/
18+
19+
package main
20+
21+
import (
22+
"os"
23+
24+
"github.com/sirupsen/logrus"
25+
"github.com/urfave/cli/v2"
26+
27+
"github.com/NVIDIA/k8s-driver-manager/internal/info"
28+
)
29+
30+
func main() {
31+
logger := logrus.New()
32+
logger.SetFormatter(&logrus.TextFormatter{
33+
FullTimestamp: true,
34+
DisableQuote: true,
35+
})
36+
37+
app := cli.NewApp()
38+
app.Name = "vfio-manage"
39+
app.Usage = "Manage VFIO driver binding for NVIDIA GPU devices"
40+
app.Version = info.GetVersionString()
41+
42+
app.Commands = []*cli.Command{
43+
newBindCommand(logger),
44+
newUnbindCommand(logger),
45+
}
46+
if err := app.Run(os.Args); err != nil {
47+
logger.Fatal(err)
48+
}
49+
}

cmd/vfio-manage/unbind.go

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
//go:build !darwin && !windows
2+
3+
/*
4+
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
5+
*
6+
* Licensed under the Apache License, Version 2.0 (the "License");
7+
* you may not use this file except in compliance with the License.
8+
* You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing, software
13+
* distributed under the License is distributed on an "AS IS" BASIS,
14+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
* See the License for the specific language governing permissions and
16+
* limitations under the License.
17+
*/
18+
19+
package main
20+
21+
import (
22+
"fmt"
23+
24+
"github.com/sirupsen/logrus"
25+
"github.com/urfave/cli/v2"
26+
27+
"github.com/NVIDIA/k8s-driver-manager/internal/nvpci"
28+
)
29+
30+
type unbindCommand struct {
31+
logger *logrus.Logger
32+
nvpciLib nvpci.Interface
33+
}
34+
35+
type unbindOptions struct {
36+
all bool
37+
deviceID string
38+
}
39+
40+
// newUnbindCommand constructs an unbind command with the specified logger
41+
func newUnbindCommand(logger *logrus.Logger) *cli.Command {
42+
c := unbindCommand{
43+
logger: logger,
44+
nvpciLib: nvpci.New(),
45+
}
46+
return c.build()
47+
}
48+
49+
// build the unbind command
50+
func (m unbindCommand) build() *cli.Command {
51+
cfg := unbindOptions{}
52+
53+
// Create the 'unbind' command
54+
c := cli.Command{
55+
Name: "unbind",
56+
Usage: "Unbind device(s) from their current driver",
57+
Before: func(c *cli.Context) error {
58+
return m.validateFlags(&cfg)
59+
},
60+
Action: func(c *cli.Context) error {
61+
return m.run(&cfg)
62+
},
63+
Flags: []cli.Flag{
64+
&cli.BoolFlag{
65+
Name: "all",
66+
Aliases: []string{"a"},
67+
Destination: &cfg.all,
68+
Usage: "Bind all NVIDIA devices to vfio-pci",
69+
},
70+
&cli.StringFlag{
71+
Name: "device-id",
72+
Aliases: []string{"d"},
73+
Destination: &cfg.deviceID,
74+
Usage: "Specific device ID to bind (e.g., 0000:01:00.0)",
75+
},
76+
},
77+
}
78+
79+
return &c
80+
}
81+
82+
func (m unbindCommand) validateFlags(cfg *unbindOptions) error {
83+
if !cfg.all && cfg.deviceID == "" {
84+
return fmt.Errorf("either --all or --device-id must be specified")
85+
}
86+
87+
if cfg.all && cfg.deviceID != "" {
88+
return fmt.Errorf("cannot specify both --all and --device-id")
89+
}
90+
91+
return nil
92+
}
93+
94+
func (m unbindCommand) run(cfg *unbindOptions) error {
95+
if cfg.deviceID != "" {
96+
return m.unbindDevice(cfg.deviceID)
97+
}
98+
99+
return m.unbindAll()
100+
}
101+
102+
func (m unbindCommand) unbindAll() error {
103+
devices, err := m.nvpciLib.GetGPUs()
104+
if err != nil {
105+
return fmt.Errorf("failed to get NVIDIA GPUs: %w", err)
106+
}
107+
108+
for _, dev := range devices {
109+
m.logger.Infof("Unbinding device %s", dev.Address)
110+
// (cdesiniotis) ideally this should be replaced by a call to nvdev.UnbindFromDriver()
111+
if err := m.nvpciLib.UnbindFromDriver(dev); err != nil {
112+
m.logger.Warnf("Failed to unbind device %s: %v", dev.Address, err)
113+
}
114+
}
115+
return nil
116+
}
117+
118+
func (m unbindCommand) unbindDevice(device string) error {
119+
nvdev, err := m.nvpciLib.GetGPUByPciBusID(device)
120+
if err != nil {
121+
return fmt.Errorf("failed to get NVIDIA GPU device: %w", err)
122+
}
123+
if nvdev == nil || !nvdev.IsGPU() {
124+
m.logger.Infof("Device %s is not a GPU", device)
125+
return nil
126+
}
127+
128+
m.logger.Infof("Unbinding device %s", device)
129+
130+
// (cdesiniotis) ideally this should be replaced by a call to nvdev.UnbindFromDriver()
131+
if err := m.nvpciLib.UnbindFromDriver(nvdev); err != nil {
132+
return fmt.Errorf("failed to unbind device %s from driver: %w", device, err)
133+
}
134+
135+
return nil
136+
}

deployments/container/Dockerfile.distroless

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ COPY go.mod go.mod
2323
COPY go.sum go.sum
2424
COPY vendor vendor
2525
COPY cmd/driver-manager cmd/driver-manager
26+
COPY cmd/vfio-manage cmd/vfio-manage
2627
COPY internal/ internal/
2728

2829
RUN dnf install -y wget make git gcc
@@ -49,6 +50,8 @@ ARG GIT_COMMIT="unknown"
4950
ARG VERSION_PKG="github.com/NVIDIA/k8s-driver-manager/internal/info"
5051
RUN go build -ldflags "-extldflags=-Wl,-z,lazy -s -w -X ${VERSION_PKG}.gitCommit=${GIT_COMMIT} -X ${VERSION_PKG}.version=${VERSION}" \
5152
-o driver-manager ./cmd/driver-manager/
53+
RUN go build -ldflags "-extldflags=-Wl,-z,lazy -s -w -X ${VERSION_PKG}.gitCommit=${GIT_COMMIT} -X ${VERSION_PKG}.version=${VERSION}" \
54+
-o vfio-manage ./cmd/vfio-manage/
5255

5356
ARG TARGETARCH
5457

@@ -59,7 +62,7 @@ USER 0:0
5962
SHELL ["/busybox/sh", "-c"]
6063
RUN ln -s /busybox/sh /bin/sh
6164

62-
COPY scripts/vfio-manage /usr/bin
65+
COPY --from=build /work/vfio-manage /usr/bin
6366
COPY --from=build /work/driver-manager /usr/bin
6467

6568
LABEL io.k8s.display-name="NVIDIA Driver Upgrade Manager for Kubernetes"

go.mod

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ module github.com/NVIDIA/k8s-driver-manager
33
go 1.24.0
44

55
require (
6+
github.com/NVIDIA/go-nvlib v0.8.1
67
github.com/moby/sys/mount v0.3.4
78
github.com/sirupsen/logrus v1.9.3
89
github.com/urfave/cli/v2 v2.27.7

go.sum

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@ github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 h1:L/gRVlceqvL25
22
github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E=
33
github.com/MakeNowJust/heredoc v1.0.0 h1:cXCdzVdstXyiTqTvfqk9SDHpKNjxuom+DOlyEeQ4pzQ=
44
github.com/MakeNowJust/heredoc v1.0.0/go.mod h1:mG5amYoWBHf8vpLOuehzbGGw0EHxpZZ6lCpQ4fNJ8LE=
5+
github.com/NVIDIA/go-nvlib v0.8.1 h1:OPEHVvn3zcV5OXB68A7WRpeCnYMRSPl7LdeJH/d3gZI=
6+
github.com/NVIDIA/go-nvlib v0.8.1/go.mod h1:7mzx9FSdO9fXWP9NKuZmWkCwhkEcSWQFe2tmFwtLb9c=
57
github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5 h1:0CwZNZbxp69SHPdPJAN/hZIm0C4OItdklCFmMRWYpio=
68
github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5/go.mod h1:wHh0iHkYZB8zMSxRWpUBQtwG5a7fFgvEO+odwuTv2gs=
79
github.com/blang/semver/v4 v4.0.0 h1:1PFHFE6yCCTv8C1TeyNNarDzntLi7wMI5i/pzqYIsAM=
@@ -131,8 +133,8 @@ github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/
131133
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
132134
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
133135
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
134-
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
135-
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
136+
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
137+
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
136138
github.com/urfave/cli/v2 v2.27.7 h1:bH59vdhbjLv3LAvIu6gd0usJHgoTTPhCFib8qqOwXYU=
137139
github.com/urfave/cli/v2 v2.27.7/go.mod h1:CyNAG/xg+iAOg0N4MPGZqVmv2rCoP267496AOXUZjA4=
138140
github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM=

0 commit comments

Comments
 (0)