@@ -18,123 +18,133 @@ package main
1818
1919import (
2020 "fmt"
21- "io"
2221 "os"
2322 "path/filepath"
2423 "sort"
25- "strings"
2624
2725 log "github.com/sirupsen/logrus"
2826)
2927
3028type executableTarget struct {
31- dotfileName string
3229 wrapperName string
3330}
3431
3532type executable struct {
36- source string
37- target executableTarget
38- env map [string ]string
39- preLines []string
40- argLines []string
33+ source string
34+ target executableTarget
35+ argv []string
36+ envm map [string ]string
4137}
4238
4339// install installs an executable component of the NVIDIA container toolkit. The source executable
4440// is copied to a `.real` file and a wapper is created to set up the environment as required.
4541func (e executable ) install (destFolder string ) (string , error ) {
42+ if destFolder == "" {
43+ return "" , fmt .Errorf ("destination folder must be specified" )
44+ }
45+ if e .source == "" {
46+ return "" , fmt .Errorf ("source executable must be specified" )
47+ }
4648 log .Infof ("Installing executable '%v' to %v" , e .source , destFolder )
47-
48- dotfileName := e .dotfileName ()
49-
50- installedDotfileName , err := installFileToFolderWithName (destFolder , dotfileName , e .source )
49+ dotRealFilename := e .dotRealFilename ()
50+ dotRealPath , err := installFileToFolderWithName (destFolder , dotRealFilename , e .source )
5151 if err != nil {
52- return "" , fmt .Errorf ("error installing file '%v' as '%v': %v" , e .source , dotfileName , err )
52+ return "" , fmt .Errorf ("error installing file '%v' as '%v': %v" , e .source , dotRealFilename , err )
5353 }
54- log .Infof ("Installed '%v'" , installedDotfileName )
54+ log .Infof ("Installed '%v'" , dotRealPath )
5555
56- wrapperFilename , err := e .installWrapper (destFolder , installedDotfileName )
56+ wrapperPath , err := e .installWrapper (destFolder )
5757 if err != nil {
58- return "" , fmt .Errorf ("error wrapping '%v' : %v" , installedDotfileName , err )
58+ return "" , fmt .Errorf ("error installing wrapper : %v" , err )
5959 }
60- log .Infof ("Installed wrapper '%v'" , wrapperFilename )
61-
62- return wrapperFilename , nil
60+ log .Infof ("Installed wrapper '%v'" , wrapperPath )
61+ return wrapperPath , nil
6362}
6463
65- func (e executable ) dotfileName () string {
66- return e .target . dotfileName
64+ func (e executable ) dotRealFilename () string {
65+ return e .wrapperName () + ".real"
6766}
6867
6968func (e executable ) wrapperName () string {
69+ if e .target .wrapperName == "" {
70+ return filepath .Base (e .source )
71+ }
7072 return e .target .wrapperName
7173}
7274
73- func (e executable ) installWrapper (destFolder string , dotfileName string ) (string , error ) {
74- wrapperPath := filepath .Join (destFolder , e .wrapperName ())
75- wrapper , err := os .Create (wrapperPath )
75+ func (e executable ) installWrapper (destFolder string ) (string , error ) {
76+ currentExe , err := os .Executable ()
7677 if err != nil {
77- return "" , fmt .Errorf ("error creating executable wrapper : %v" , err )
78+ return "" , fmt .Errorf ("error getting current executable : %v" , err )
7879 }
79- defer wrapper .Close ()
80-
81- err = e .writeWrapperTo (wrapper , destFolder , dotfileName )
80+ src := filepath .Join (filepath .Dir (currentExe ), "wrapper" )
81+ wrapperPath , err := installFileToFolderWithName (destFolder , e .wrapperName (), src )
8282 if err != nil {
83- return "" , fmt .Errorf ("error writing wrapper contents: %v" , err )
83+ return "" , fmt .Errorf ("error installing wrapper program: %v" , err )
84+ }
85+ err = e .writeWrapperArgv (wrapperPath , destFolder )
86+ if err != nil {
87+ return "" , fmt .Errorf ("error writing wrapper argv: %v" , err )
88+ }
89+ err = e .writeWrapperEnvv (wrapperPath , destFolder )
90+ if err != nil {
91+ return "" , fmt .Errorf ("error writing wrapper envv: %v" , err )
8492 }
85-
8693 err = ensureExecutable (wrapperPath )
8794 if err != nil {
8895 return "" , fmt .Errorf ("error making wrapper executable: %v" , err )
8996 }
9097 return wrapperPath , nil
9198}
9299
93- func (e executable ) writeWrapperTo (wrapper io.Writer , destFolder string , dotfileName string ) error {
100+ func (e executable ) writeWrapperArgv (wrapperPath string , destFolder string ) error {
101+ if e .argv == nil {
102+ return nil
103+ }
94104 r := newReplacements (destDirPattern , destFolder )
95-
96- // Add the shebang
97- fmt .Fprintln (wrapper , "#! /bin/sh" )
98-
99- // Add the preceding lines if any
100- for _ , line := range e .preLines {
101- fmt .Fprintf (wrapper , "%s\n " , r .apply (line ))
105+ f , err := os .OpenFile (wrapperPath + ".argv" , os .O_RDWR | os .O_CREATE | os .O_TRUNC , 0440 )
106+ if err != nil {
107+ return err
102108 }
103-
104- // Update the path to include the destination folder
105- var env map [string ]string
106- if e .env == nil {
107- env = make (map [string ]string )
108- } else {
109- env = e .env
109+ defer f .Close ()
110+ for _ , arg := range e .argv {
111+ fmt .Fprintf (f , "%s\n " , r .apply (arg ))
110112 }
113+ return nil
114+ }
111115
112- path , specified := env ["PATH" ]
113- if ! specified {
114- path = "$PATH"
116+ func (e executable ) writeWrapperEnvv (wrapperPath string , destFolder string ) error {
117+ r := newReplacements (destDirPattern , destFolder )
118+ f , err := os .OpenFile (wrapperPath + ".envv" , os .O_RDWR | os .O_CREATE | os .O_TRUNC , 0440 )
119+ if err != nil {
120+ return err
115121 }
116- env [ "PATH" ] = strings . Join ([] string { destFolder , path }, ":" )
122+ defer f . Close ( )
117123
118- var sortedEnvvars []string
119- for e := range env {
120- sortedEnvvars = append (sortedEnvvars , e )
124+ // Update PATH to insert the destination folder at the head.
125+ var envm map [string ]string
126+ if e .envm == nil {
127+ envm = make (map [string ]string )
128+ } else {
129+ envm = e .envm
121130 }
122- sort .Strings (sortedEnvvars )
123-
124- for _ , e := range sortedEnvvars {
125- v := env [e ]
126- fmt .Fprintf (wrapper , "%s=%s \\ \n " , e , r .apply (v ))
131+ if path , ok := envm ["PATH" ]; ok {
132+ envm ["PATH" ] = destFolder + ":" + path
133+ } else {
134+ // Replace PATH with <PATH, which instructs wrapper to insert the value at the head of a
135+ // colon-separated environment variable list.
136+ delete (envm , "PATH" )
137+ envm ["<PATH" ] = destFolder
127138 }
128- // Add the call to the target executable
129- fmt .Fprintf (wrapper , "%s \\ \n " , dotfileName )
130139
131- // Insert additional lines in the `arg` list
132- for _ , line := range e .argLines {
133- fmt .Fprintf (wrapper , "\t %s \\ \n " , r .apply (line ))
140+ var envv []string
141+ for k , v := range envm {
142+ envv = append (envv , k + "=" + r .apply (v ))
143+ }
144+ sort .Strings (envv )
145+ for _ , e := range envv {
146+ fmt .Fprintf (f , "%s\n " , e )
134147 }
135- // Add the script arguments "$@"
136- fmt .Fprintln (wrapper , "\t \" $@\" " )
137-
138148 return nil
139149}
140150
0 commit comments