1+ #!/usr/bin/env python3
2+
3+ # SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4+ # SPDX-License-Identifier: Apache-2.0
5+
6+ import json
7+ import yaml
8+ import sys
9+ import os
10+ from typing import Dict , List , Any
11+
12+ def write_output (key : str , value : str ):
13+ """Write GitHub Actions output."""
14+ github_output = os .environ .get ('GITHUB_OUTPUT' , '/dev/null' )
15+ with open (github_output , 'a' ) as f :
16+ f .write (f"{ key } ={ value } \n " )
17+ print (f"{ key } ={ value } " )
18+
19+ def write_json_file (filename : str , data : Any ):
20+ """Write data to a JSON file."""
21+ with open (filename , 'w' ) as f :
22+ json .dump (data , f , indent = 2 )
23+
24+ def write_text_file (filename : str , content : str ):
25+ """Write content to a text file."""
26+ with open (filename , 'w' ) as f :
27+ f .write (content )
28+
29+ def explode_std_versions (matrix_entries : List [Dict [str , Any ]]) -> List [Dict [str , Any ]]:
30+ """Explode std arrays into individual entries."""
31+ result = []
32+ for entry in matrix_entries :
33+ if 'std' in entry and isinstance (entry ['std' ], list ):
34+ for std in entry ['std' ]:
35+ new_entry = entry .copy ()
36+ new_entry ['std' ] = std
37+ result .append (new_entry )
38+ else :
39+ result .append (entry )
40+ return result
41+
42+ def extract_matrix (file_path : str , matrix_type : str ):
43+ """Extract and process the matrix configuration."""
44+ try :
45+ with open (file_path , 'r' ) as f :
46+ data = yaml .safe_load (f )
47+
48+ if matrix_type not in data :
49+ print (f"Error: Matrix type '{ matrix_type } ' not found in { file_path } " , file = sys .stderr )
50+ sys .exit (1 )
51+
52+ matrix = data [matrix_type ]
53+
54+ # Write devcontainer version
55+ devcontainer_version = data .get ('devcontainer_version' , '25.08' )
56+ write_output ("DEVCONTAINER_VERSION" , devcontainer_version )
57+
58+ # Process nvcc matrix
59+ if 'nvcc' not in matrix :
60+ print (f"Error: 'nvcc' section not found in { matrix_type } matrix" , file = sys .stderr )
61+ sys .exit (1 )
62+
63+ nvcc_matrix = matrix ['nvcc' ]
64+ nvcc_full_matrix = explode_std_versions (nvcc_matrix )
65+
66+ write_output ("NVCC_FULL_MATRIX" , json .dumps (nvcc_full_matrix ))
67+
68+ # Extract unique CUDA versions
69+ cuda_versions = list (set (entry ['cuda' ] for entry in nvcc_full_matrix ))
70+ cuda_versions .sort ()
71+ write_output ("CUDA_VERSIONS" , json .dumps (cuda_versions ))
72+
73+ # Extract unique host compilers
74+ host_compilers = list (set (entry ['compiler' ]['name' ] for entry in nvcc_full_matrix ))
75+ host_compilers .sort ()
76+ write_output ("HOST_COMPILERS" , json .dumps (host_compilers ))
77+
78+ # Create per-cuda-compiler matrix
79+ per_cuda_compiler = {}
80+ for entry in nvcc_full_matrix :
81+ key = f"{ entry ['cuda' ]} -{ entry ['compiler' ]['name' ]} "
82+ if key not in per_cuda_compiler :
83+ per_cuda_compiler [key ] = []
84+ per_cuda_compiler [key ].append (entry )
85+
86+ write_output ("PER_CUDA_COMPILER_MATRIX" , json .dumps (per_cuda_compiler ))
87+
88+ # Create output directory and write detailed files (CCCL approach)
89+ os .makedirs ("workflow" , exist_ok = True )
90+
91+ # Write individual output files for debugging and artifacts
92+ write_json_file ("workflow/devcontainer_version.json" , {"version" : devcontainer_version })
93+ write_json_file ("workflow/nvcc_full_matrix.json" , nvcc_full_matrix )
94+ write_json_file ("workflow/cuda_versions.json" , cuda_versions )
95+ write_json_file ("workflow/host_compilers.json" , host_compilers )
96+ write_json_file ("workflow/per_cuda_compiler_matrix.json" , per_cuda_compiler )
97+
98+ # Write summary
99+ summary = {
100+ "total_matrix_entries" : len (nvcc_full_matrix ),
101+ "cuda_compiler_combinations" : len (per_cuda_compiler ),
102+ "cuda_versions" : cuda_versions ,
103+ "host_compilers" : host_compilers
104+ }
105+ write_json_file ("workflow/matrix_summary.json" , summary )
106+
107+ # Write human-readable summary
108+ summary_text = f"Matrix Summary:\n "
109+ summary_text += f" Total matrix entries: { len (nvcc_full_matrix )} \n "
110+ summary_text += f" CUDA versions: { ', ' .join (cuda_versions )} \n "
111+ summary_text += f" Host compilers: { ', ' .join (host_compilers )} \n "
112+ summary_text += f" CUDA-compiler combinations: { len (per_cuda_compiler )} \n \n "
113+ summary_text += "Combinations:\n "
114+ for key , entries in per_cuda_compiler .items ():
115+ summary_text += f" { key } : { len (entries )} entries\n "
116+
117+ write_text_file ("workflow/matrix_summary.txt" , summary_text )
118+
119+ print (f"Successfully processed { len (nvcc_full_matrix )} matrix entries" , file = sys .stderr )
120+ print (f"Generated { len (per_cuda_compiler )} cuda-compiler combinations" , file = sys .stderr )
121+ print ("Matrix data written to workflow/ directory" , file = sys .stderr )
122+
123+ except FileNotFoundError :
124+ print (f"Error: Matrix file '{ file_path } ' not found" , file = sys .stderr )
125+ sys .exit (1 )
126+ except yaml .YAMLError as e :
127+ print (f"Error parsing YAML file '{ file_path } ': { e } " , file = sys .stderr )
128+ sys .exit (1 )
129+ except KeyError as e :
130+ print (f"Error: Missing required key in matrix file: { e } " , file = sys .stderr )
131+ sys .exit (1 )
132+ except Exception as e :
133+ print (f"Unexpected error processing matrix: { e } " , file = sys .stderr )
134+ sys .exit (1 )
135+
136+ def main ():
137+ if len (sys .argv ) != 3 :
138+ print ("Usage: compute-matrix.py MATRIX_FILE MATRIX_TYPE" , file = sys .stderr )
139+ print (" MATRIX_FILE : The path to the matrix file." , file = sys .stderr )
140+ print (" MATRIX_TYPE : The desired matrix. Supported values: 'pull_request'" , file = sys .stderr )
141+ sys .exit (1 )
142+
143+ matrix_file = sys .argv [1 ]
144+ matrix_type = sys .argv [2 ]
145+
146+ if matrix_type != "pull_request" :
147+ print (f"Error: Unsupported matrix type '{ matrix_type } '. Only 'pull_request' is supported." , file = sys .stderr )
148+ sys .exit (1 )
149+
150+ print (f"Input matrix file: { matrix_file } " , file = sys .stderr )
151+ print (f"Matrix Type: { matrix_type } " , file = sys .stderr )
152+
153+ # Show matrix file content for debugging
154+ try :
155+ with open (matrix_file , 'r' ) as f :
156+ content = f .read ()
157+ print ("Matrix file content:" , file = sys .stderr )
158+ print (content , file = sys .stderr )
159+ print ("=" * 50 , file = sys .stderr )
160+ except Exception as e :
161+ print (f"Warning: Could not read matrix file for debugging: { e } " , file = sys .stderr )
162+
163+ extract_matrix (matrix_file , matrix_type )
164+
165+ if __name__ == "__main__" :
166+ main ()
0 commit comments