@@ -18,6 +18,8 @@ package containerd
1818
1919import (
2020 "fmt"
21+ "os"
22+ "path/filepath"
2123
2224 "github.com/NVIDIA/nvidia-container-toolkit/pkg/config/engine"
2325 "github.com/NVIDIA/nvidia-container-toolkit/pkg/config/toml"
@@ -123,12 +125,38 @@ func (c *Config) EnableCDI() {
123125 * c .Tree = config
124126}
125127
126- // RemoveRuntime removes a runtime from the docker config
128+ // RemoveRuntime removes a runtime from the containerd config
127129func (c * Config ) RemoveRuntime (name string ) error {
128130 if c == nil || c .Tree == nil {
129131 return nil
130132 }
131133
134+ // If using NVIDIA-specific configuration, handle file cleanup
135+ if c .nvidiaConfig != "" {
136+ // Check if all NVIDIA runtimes are being removed
137+ remainingNvidiaRuntimes := 0
138+ if runtimes := c .GetPath ([]string {"plugins" , c .CRIRuntimePluginName , "containerd" , "runtimes" }); runtimes != nil {
139+ if runtimesTree , ok := runtimes .(* toml.Tree ); ok {
140+ for _ , runtimeName := range runtimesTree .Keys () {
141+ if c .isNvidiaRuntime (runtimeName ) && runtimeName != name {
142+ remainingNvidiaRuntimes ++
143+ }
144+ }
145+ }
146+ }
147+
148+ // If this is the last NVIDIA runtime, remove the NVIDIA config file
149+ if remainingNvidiaRuntimes == 0 {
150+ if err := os .Remove (c .nvidiaConfig ); err != nil && ! os .IsNotExist (err ) {
151+ c .Logger .Warningf ("Failed to remove NVIDIA config file %s: %v" , c .nvidiaConfig , err )
152+ } else {
153+ c .Logger .Infof ("Removed NVIDIA config file: %s" , c .nvidiaConfig )
154+ }
155+ // Don't modify the in-memory tree when using NVIDIA-specific configuration
156+ return nil
157+ }
158+ }
159+
132160 config := * c .Tree
133161
134162 config .DeletePath ([]string {"plugins" , c .CRIRuntimePluginName , "containerd" , "runtimes" , name })
@@ -154,3 +182,134 @@ func (c *Config) RemoveRuntime(name string) error {
154182 * c .Tree = config
155183 return nil
156184}
185+
186+ // Save writes the config to the specified path or NVIDIA-specific config file
187+ func (c * Config ) Save (path string ) (int64 , error ) {
188+ if c .nvidiaConfig == "" {
189+ // Backward compatibility: save to main config
190+ return c .Tree .Save (path )
191+ }
192+
193+ // Ensure directory for NVIDIA config file exists
194+ dir := filepath .Dir (c .nvidiaConfig )
195+ if err := os .MkdirAll (dir , 0755 ); err != nil {
196+ return 0 , fmt .Errorf ("failed to create directory for NVIDIA config: %w" , err )
197+ }
198+
199+ // Save runtime configs to NVIDIA config file
200+ nvidiaConfig := c .extractRuntimeConfig ()
201+ n , err := nvidiaConfig .Save (c .nvidiaConfig )
202+ if err != nil {
203+ return n , fmt .Errorf ("failed to save NVIDIA config: %w" , err )
204+ }
205+
206+ // Update main config with imports directive
207+ if err := c .updateMainConfigImports (path ); err != nil {
208+ // Try to clean up the NVIDIA config file on error
209+ os .Remove (c .nvidiaConfig )
210+ return n , fmt .Errorf ("failed to update main config imports: %w" , err )
211+ }
212+
213+ c .Logger .Infof ("Wrote NVIDIA runtime configuration to: %s" , c .nvidiaConfig )
214+ return n , nil
215+ }
216+
217+ // extractRuntimeConfig creates a new config tree with only runtime configurations
218+ func (c * Config ) extractRuntimeConfig () * toml.Tree {
219+ config , _ := toml .TreeFromMap (map [string ]interface {}{
220+ "version" : c .Version ,
221+ })
222+
223+ // Extract runtime configurations for NVIDIA runtimes
224+ if runtimes := c .GetPath ([]string {"plugins" , c .CRIRuntimePluginName , "containerd" , "runtimes" }); runtimes != nil {
225+ if runtimesTree , ok := runtimes .(* toml.Tree ); ok {
226+ nvidiaRuntimes , _ := toml .TreeFromMap (map [string ]interface {}{})
227+ for _ , name := range runtimesTree .Keys () {
228+ if c .isNvidiaRuntime (name ) {
229+ if runtime := runtimesTree .Get (name ); runtime != nil {
230+ nvidiaRuntimes .Set (name , runtime )
231+ }
232+ }
233+ }
234+ if len (nvidiaRuntimes .Keys ()) > 0 {
235+ config .SetPath ([]string {"plugins" , c .CRIRuntimePluginName , "containerd" , "runtimes" }, nvidiaRuntimes )
236+ }
237+ }
238+ }
239+
240+ // Extract default runtime name if it's one of ours
241+ if defaultRuntime , ok := c .GetPath ([]string {"plugins" , c .CRIRuntimePluginName , "containerd" , "default_runtime_name" }).(string ); ok {
242+ if c .isNvidiaRuntime (defaultRuntime ) {
243+ config .SetPath ([]string {"plugins" , c .CRIRuntimePluginName , "containerd" , "default_runtime_name" }, defaultRuntime )
244+ }
245+ }
246+
247+ // Extract CDI enablement
248+ if cdiEnabled , ok := c .GetPath ([]string {"plugins" , c .CRIRuntimePluginName , "enable_cdi" }).(bool ); ok && cdiEnabled {
249+ config .SetPath ([]string {"plugins" , c .CRIRuntimePluginName , "enable_cdi" }, true )
250+ }
251+
252+ return config
253+ }
254+
255+ // updateMainConfigImports ensures the main config includes an imports directive
256+ func (c * Config ) updateMainConfigImports (path string ) error {
257+ // Load the main config file
258+ mainConfig , err := toml .FromFile (path ).Load ()
259+ if err != nil {
260+ // If the file doesn't exist, create a minimal config with imports
261+ if os .IsNotExist (err ) {
262+ mainConfig , _ = toml .TreeFromMap (map [string ]interface {}{
263+ "version" : c .Version ,
264+ })
265+ } else {
266+ return fmt .Errorf ("failed to load main config: %w" , err )
267+ }
268+ }
269+
270+ // Add imports directive if not present
271+ importPattern := c .nvidiaConfig
272+ imports := mainConfig .Get ("imports" )
273+ if imports == nil {
274+ mainConfig .Set ("imports" , []string {importPattern })
275+ } else if importsList , ok := imports .([]interface {}); ok {
276+ // Check if the import pattern already exists
277+ found := false
278+ for _ , imp := range importsList {
279+ if impStr , ok := imp .(string ); ok && impStr == importPattern {
280+ found = true
281+ break
282+ }
283+ }
284+ if ! found {
285+ // Add our import pattern
286+ importsList = append (importsList , importPattern )
287+ mainConfig .Set ("imports" , importsList )
288+ }
289+ } else if importsStrList , ok := imports .([]string ); ok {
290+ // Check if the import pattern already exists
291+ found := false
292+ for _ , imp := range importsStrList {
293+ if imp == importPattern {
294+ found = true
295+ break
296+ }
297+ }
298+ if ! found {
299+ // Add our import pattern
300+ importsStrList = append (importsStrList , importPattern )
301+ mainConfig .Set ("imports" , importsStrList )
302+ }
303+ } else {
304+ return fmt .Errorf ("unexpected imports type: %T" , imports )
305+ }
306+
307+ // Save the updated main config
308+ _ , err = mainConfig .Save (path )
309+ return err
310+ }
311+
312+ // isNvidiaRuntime checks if the runtime name is an NVIDIA runtime
313+ func (c * Config ) isNvidiaRuntime (name string ) bool {
314+ return name == "nvidia" || name == "nvidia-cdi" || name == "nvidia-legacy"
315+ }
0 commit comments