@@ -100,7 +100,7 @@ func NewRunner(opts ...runnerOption) Runner {
100100// NewNestedContainerRunner creates a new nested container runner.
101101// A nested container runs a container inside another container based on a
102102// given runner (remote or local).
103- func NewNestedContainerRunner (runner Runner , baseImage string , mountToolkitFromHost bool , containerName string , cacheDir string ) (Runner , error ) {
103+ func NewNestedContainerRunner (runner Runner , baseImage string , mountToolkitFromHost bool , containerName string , cacheDir string , requiresGPUs bool ) (Runner , error ) {
104104 // If a container with the same name exists from a previous test run, remove it first.
105105 // Ignore errors as container might not exist
106106 _ , _ , err := runner .Run (fmt .Sprintf ("docker rm -f %s 2>/dev/null || true" , containerName ))
@@ -110,6 +110,16 @@ func NewNestedContainerRunner(runner Runner, baseImage string, mountToolkitFromH
110110
111111 var additionalContainerArguments []string
112112
113+ if requiresGPUs {
114+ // If the container requires access to GPUs we explicitly add the nvidia
115+ // runtime and set `NVIDIA_VISIBLE_DEVICES` to trigger jit-cdi spec
116+ // generation.
117+ additionalContainerArguments = append (additionalContainerArguments ,
118+ "--runtime=nvidia" ,
119+ "-e NVIDIA_VISIBLE_DEVICES=runtime.nvidia.com/gpu=all" ,
120+ )
121+ }
122+
113123 if cacheDir != "" {
114124 additionalContainerArguments = append (additionalContainerArguments ,
115125 "-v " + cacheDir + ":" + cacheDir + ":ro" ,
@@ -302,10 +312,6 @@ func connectOrDie(sshKey, sshUser, host, port string) (*ssh.Client, error) {
302312
303313// outerContainer represents a template to start a container with
304314// a name specified.
305- // The container is given access to all NVIDIA gpus by explicitly using the
306- // nvidia runtime and the `runtime.nvidia.com/gpu=all` device to trigger JIT
307- // CDI spec generation.
308- // The template also allows for additional arguments to be specified.
309315type outerContainer struct {
310316 Name string
311317 BaseImage string
@@ -314,8 +320,6 @@ type outerContainer struct {
314320
315321func (o * outerContainer ) Render () (string , error ) {
316322 tmpl , err := template .New ("startContainer" ).Parse (`docker run -d --name {{.Name}} --privileged \
317- -e NVIDIA_VISIBLE_DEVICES=runtime.nvidia.com/gpu=all \
318- -e NVIDIA_DRIVER_CAPABILITIES=all \
319323{{ range $i, $a := .AdditionalArguments -}}
320324{{ $a }} \
321325{{ end -}}
0 commit comments