@@ -8,10 +8,10 @@ import (
88 "strings"
99 "syscall"
1010
11- log "github.com/sirupsen/logrus"
1211 "github.com/urfave/cli/v2"
1312 "golang.org/x/sys/unix"
1413
14+ "github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
1515 "github.com/NVIDIA/nvidia-container-toolkit/tools/container/runtime"
1616 "github.com/NVIDIA/nvidia-container-toolkit/tools/container/toolkit"
1717)
@@ -51,12 +51,40 @@ func (o options) toolkitRoot() string {
5151var Version = "development"
5252
5353func main () {
54- remainingArgs , root , err := ParseArgs (os .Args )
54+ logger := logger .New ()
55+
56+ remainingArgs , root , err := ParseArgs (logger , os .Args )
5557 if err != nil {
56- log .Errorf ("Error: unable to parse arguments: %v" , err )
58+ logger .Errorf ("Error: unable to parse arguments: %v" , err )
59+ os .Exit (1 )
60+ }
61+
62+ c := new (logger , root )
63+
64+ // Run the CLI
65+ logger .Infof ("Starting %v" , c .Name )
66+ if err := c .Run (remainingArgs ); err != nil {
67+ logger .Errorf ("error running nvidia-toolkit: %v" , err )
5768 os .Exit (1 )
5869 }
5970
71+ logger .Infof ("Completed %v" , c .Name )
72+ }
73+
74+ type app struct {
75+ logger logger.Interface
76+ defaultRoot string
77+ }
78+
79+ func new (logger logger.Interface , defaultRoot string ) * cli.App {
80+ a := app {
81+ logger : logger ,
82+ defaultRoot : defaultRoot ,
83+ }
84+ return a .build ()
85+ }
86+
87+ func (a app ) build () * cli.App {
6088 options := options {
6189 toolkitOptions : toolkit.Options {},
6290 }
@@ -68,10 +96,10 @@ func main() {
6896 c .Description = "DESTINATION points to the host path underneath which the nvidia-container-toolkit should be installed.\n It will be installed at ${DESTINATION}/toolkit"
6997 c .Version = Version
7098 c .Before = func (ctx * cli.Context ) error {
71- return validateFlags (ctx , & options )
99+ return a . validateFlags (ctx , & options )
72100 }
73101 c .Action = func (ctx * cli.Context ) error {
74- return Run (ctx , & options )
102+ return a . Run (ctx , & options )
75103 }
76104
77105 // Setup flags for the CLI
@@ -102,7 +130,7 @@ func main() {
102130 },
103131 & cli.StringFlag {
104132 Name : "root" ,
105- Value : root ,
133+ Value : a . defaultRoot ,
106134 Usage : "the folder where the NVIDIA Container Toolkit is to be installed. It will be installed to `ROOT`/toolkit" ,
107135 Destination : & options .root ,
108136 EnvVars : []string {"ROOT" },
@@ -119,17 +147,10 @@ func main() {
119147 c .Flags = append (c .Flags , toolkit .Flags (& options .toolkitOptions )... )
120148 c .Flags = append (c .Flags , runtime .Flags (& options .runtimeOptions )... )
121149
122- // Run the CLI
123- log .Infof ("Starting %v" , c .Name )
124- if err := c .Run (remainingArgs ); err != nil {
125- log .Errorf ("error running nvidia-toolkit: %v" , err )
126- os .Exit (1 )
127- }
128-
129- log .Infof ("Completed %v" , c .Name )
150+ return c
130151}
131152
132- func validateFlags (_ * cli.Context , o * options ) error {
153+ func ( a * app ) validateFlags (_ * cli.Context , o * options ) error {
133154 if o .root == "" {
134155 return fmt .Errorf ("the install root must be specified" )
135156 }
@@ -139,6 +160,7 @@ func validateFlags(_ *cli.Context, o *options) error {
139160 if filepath .Base (o .pidFile ) != toolkitPidFilename {
140161 return fmt .Errorf ("invalid toolkit.pid path %v" , o .pidFile )
141162 }
163+
142164 if err := toolkit .ValidateOptions (& o .toolkitOptions , o .toolkitRoot ()); err != nil {
143165 return err
144166 }
@@ -149,12 +171,12 @@ func validateFlags(_ *cli.Context, o *options) error {
149171}
150172
151173// Run runs the core logic of the CLI
152- func Run (c * cli.Context , o * options ) error {
153- err := initialize (o .pidFile )
174+ func ( a * app ) Run (c * cli.Context , o * options ) error {
175+ err := a . initialize (o .pidFile )
154176 if err != nil {
155177 return fmt .Errorf ("unable to initialize: %v" , err )
156178 }
157- defer shutdown (o .pidFile )
179+ defer a . shutdown (o .pidFile )
158180
159181 if len (o .toolkitOptions .ContainerRuntimeRuntimes .Value ()) == 0 {
160182 lowlevelRuntimePaths , err := runtime .GetLowlevelRuntimePaths (& o .runtimeOptions , o .runtime )
@@ -176,7 +198,7 @@ func Run(c *cli.Context, o *options) error {
176198 }
177199
178200 if ! o .noDaemon {
179- err = waitForSignal ()
201+ err = a . waitForSignal ()
180202 if err != nil {
181203 return fmt .Errorf ("unable to wait for signal: %v" , err )
182204 }
@@ -192,8 +214,8 @@ func Run(c *cli.Context, o *options) error {
192214
193215// ParseArgs checks if a single positional argument was defined and extracts this the root.
194216// If no positional arguments are defined, it is assumed that the root is specified as a flag.
195- func ParseArgs (args []string ) ([]string , string , error ) {
196- log .Infof ("Parsing arguments" )
217+ func ParseArgs (logger logger. Interface , args []string ) ([]string , string , error ) {
218+ logger .Infof ("Parsing arguments" )
197219
198220 if len (args ) < 2 {
199221 return args , "" , nil
@@ -218,8 +240,8 @@ func ParseArgs(args []string) ([]string, string, error) {
218240 return nil , "" , fmt .Errorf ("unexpected positional argument(s) %v" , args [2 :lastPositionalArg + 1 ])
219241}
220242
221- func initialize (pidFile string ) error {
222- log .Infof ("Initializing" )
243+ func ( a * app ) initialize (pidFile string ) error {
244+ a . logger .Infof ("Initializing" )
223245
224246 if dir := filepath .Dir (pidFile ); dir != "" {
225247 err := os .MkdirAll (dir , 0755 )
@@ -235,8 +257,8 @@ func initialize(pidFile string) error {
235257
236258 err = unix .Flock (int (f .Fd ()), unix .LOCK_EX | unix .LOCK_NB )
237259 if err != nil {
238- log .Warningf ("Unable to get exclusive lock on '%v'" , pidFile )
239- log .Warningf ("This normally means an instance of the NVIDIA toolkit Container is already running, aborting" )
260+ a . logger .Warningf ("Unable to get exclusive lock on '%v'" , pidFile )
261+ a . logger .Warningf ("This normally means an instance of the NVIDIA toolkit Container is already running, aborting" )
240262 return fmt .Errorf ("unable to get flock on pidfile: %v" , err )
241263 }
242264
@@ -253,27 +275,27 @@ func initialize(pidFile string) error {
253275 case <- waitingForSignal :
254276 signalReceived <- true
255277 default :
256- log .Infof ("Signal received, exiting early" )
257- shutdown (pidFile )
278+ a . logger .Infof ("Signal received, exiting early" )
279+ a . shutdown (pidFile )
258280 os .Exit (0 )
259281 }
260282 }()
261283
262284 return nil
263285}
264286
265- func waitForSignal () error {
266- log .Infof ("Waiting for signal" )
287+ func ( a * app ) waitForSignal () error {
288+ a . logger .Infof ("Waiting for signal" )
267289 waitingForSignal <- true
268290 <- signalReceived
269291 return nil
270292}
271293
272- func shutdown (pidFile string ) {
273- log .Infof ("Shutting Down" )
294+ func ( a * app ) shutdown (pidFile string ) {
295+ a . logger .Infof ("Shutting Down" )
274296
275297 err := os .Remove (pidFile )
276298 if err != nil {
277- log .Warningf ("Unable to remove pidfile: %v" , err )
299+ a . logger .Warningf ("Unable to remove pidfile: %v" , err )
278300 }
279301}
0 commit comments