@@ -18,9 +18,12 @@ package root
1818
1919import (
2020 "fmt"
21+ "io"
22+ "os"
2123
2224 "github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi/spec"
2325 "github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi/transform"
26+ "github.com/container-orchestrated-devices/container-device-interface/pkg/cdi"
2427 "github.com/sirupsen/logrus"
2528 "github.com/urfave/cli/v2"
2629)
@@ -32,73 +35,125 @@ type loadSaver interface {
3235
3336type command struct {
3437 logger * logrus.Logger
38+ }
3539
36- handler loadSaver
40+ type transformOptions struct {
41+ input string
42+ output string
3743}
3844
39- type config struct {
45+ type options struct {
46+ transformOptions
4047 from string
4148 to string
4249}
4350
4451// NewCommand constructs a generate-cdi command with the specified logger
45- func NewCommand (logger * logrus.Logger , specHandler loadSaver ) * cli.Command {
52+ func NewCommand (logger * logrus.Logger ) * cli.Command {
4653 c := command {
47- logger : logger ,
48- handler : specHandler ,
54+ logger : logger ,
4955 }
5056 return c .build ()
5157}
5258
5359// build creates the CLI command
5460func (m command ) build () * cli.Command {
55- cfg := config {}
61+ opts := options {}
5662
5763 c := cli.Command {
5864 Name : "root" ,
5965 Usage : "Apply a root transform to a CDI specification" ,
6066 Before : func (c * cli.Context ) error {
61- return m .validateFlags (c , & cfg )
67+ return m .validateFlags (c , & opts )
6268 },
6369 Action : func (c * cli.Context ) error {
64- return m .run (c , & cfg )
70+ return m .run (c , & opts )
6571 },
6672 }
6773
6874 c .Flags = []cli.Flag {
75+ & cli.StringFlag {
76+ Name : "input" ,
77+ Usage : "Specify the file to read the CDI specification from. If this is '-' the specification is read from STDIN" ,
78+ Value : "-" ,
79+ Destination : & opts .input ,
80+ },
81+ & cli.StringFlag {
82+ Name : "output" ,
83+ Usage : "Specify the file to output the generated CDI specification to. If this is '' the specification is output to STDOUT" ,
84+ Destination : & opts .output ,
85+ },
6986 & cli.StringFlag {
7087 Name : "from" ,
7188 Usage : "specify the root to be transformed" ,
72- Destination : & cfg .from ,
89+ Destination : & opts .from ,
7390 },
7491 & cli.StringFlag {
7592 Name : "to" ,
7693 Usage : "specify the replacement root. If this is the same as the from root, the transform is a no-op." ,
7794 Value : "" ,
78- Destination : & cfg .to ,
95+ Destination : & opts .to ,
7996 },
8097 }
8198
8299 return & c
83100}
84101
85- func (m command ) validateFlags (c * cli.Context , cfg * config ) error {
102+ func (m command ) validateFlags (c * cli.Context , opts * options ) error {
86103 return nil
87104}
88105
89- func (m command ) run (c * cli.Context , cfg * config ) error {
90- spec , err := m . handler .Load ()
106+ func (m command ) run (c * cli.Context , opts * options ) error {
107+ spec , err := opts .Load ()
91108 if err != nil {
92109 return fmt .Errorf ("failed to load CDI specification: %w" , err )
93110 }
94111
95112 err = transform .NewRootTransformer (
96- cfg .from ,
97- cfg .to ,
113+ opts .from ,
114+ opts .to ,
98115 ).Transform (spec .Raw ())
99116 if err != nil {
100117 return fmt .Errorf ("failed to transform CDI specification: %w" , err )
101118 }
102119
103- return m .handler .Save (spec )
120+ return opts .Save (spec )
121+ }
122+
123+ // Load lodas the input CDI specification
124+ func (o transformOptions ) Load () (spec.Interface , error ) {
125+ contents , err := o .getContents ()
126+ if err != nil {
127+ return nil , fmt .Errorf ("failed to read spec contents: %v" , err )
128+ }
129+
130+ raw , err := cdi .ParseSpec (contents )
131+ if err != nil {
132+ return nil , fmt .Errorf ("failed to parse CDI spec: %v" , err )
133+ }
134+
135+ return spec .New (
136+ spec .WithRawSpec (raw ),
137+ )
138+ }
139+
140+ func (o transformOptions ) getContents () ([]byte , error ) {
141+ if o .input == "-" {
142+ return io .ReadAll (os .Stdin )
143+ }
144+
145+ return os .ReadFile (o .input )
146+ }
147+
148+ // Save saves the CDI specification to the output file
149+ func (o transformOptions ) Save (s spec.Interface ) error {
150+ if o .output == "" {
151+ _ , err := s .WriteTo (os .Stdout )
152+ if err != nil {
153+ return fmt .Errorf ("failed to write CDI spec to STDOUT: %v" , err )
154+ }
155+ return nil
156+ }
157+
158+ return s .Save (o .output )
104159}
0 commit comments