@@ -61,16 +61,58 @@ static bool load_cuda_modules() {
61
61
return true ;
62
62
}
63
63
64
+ static bool check_cuda_device () {
65
+ int device_count = 0 ;
66
+ int driver_version = 0 ;
67
+
68
+ CUCTX_CUDA_CALL_ERROR (cuDriverGetVersion (&driver_version));
69
+ if (kRequiredDriverVersion > driver_version) {
70
+ RTC_LOG (LS_ERROR)
71
+ << " CUDA driver version is not higher than the required version. "
72
+ << driver_version;
73
+ return false ;
74
+ }
75
+
76
+ CUresult result = cuInit (0 );
77
+ if (result != CUDA_SUCCESS) {
78
+ RTC_LOG (LS_ERROR) << " Failed to initialize CUDA." ;
79
+ return false ;
80
+ }
81
+
82
+ result = cuDeviceGetCount (&device_count);
83
+ if (result != CUDA_SUCCESS) {
84
+ RTC_LOG (LS_ERROR) << " Failed to get CUDA device count." ;
85
+ return false ;
86
+ }
87
+
88
+ if (device_count == 0 ) {
89
+ RTC_LOG (LS_ERROR) << " No CUDA devices found." ;
90
+ return false ;
91
+ }
92
+
93
+ return true ;
94
+ }
95
+
96
+ CudaContext* CudaContext::GetInstance () {
97
+ static CudaContext instance;
98
+ return &instance;
99
+ }
100
+
101
+ bool CudaContext::IsAvailable () {
102
+ return load_cuda_modules () && check_cuda_device ();
103
+ }
104
+
64
105
bool CudaContext::Initialize () {
65
106
// Initialize CUDA context
66
107
67
108
bool success = load_cuda_modules ();
68
109
if (!success) {
69
- std::cout << " Failed to load CUDA modules. maybe the NVIDIA driver is not installed?" << std::endl;
110
+ RTC_LOG (LS_ERROR) << " Failed to load CUDA modules. maybe the NVIDIA driver "
111
+ " is not installed?" ;
70
112
return false ;
71
113
}
72
114
73
- int numDevices = 0 ;
115
+ int num_devices = 0 ;
74
116
CUdevice cu_device = 0 ;
75
117
CUcontext context = nullptr ;
76
118
@@ -84,7 +126,23 @@ bool CudaContext::Initialize() {
84
126
return false ;
85
127
}
86
128
87
- CUCTX_CUDA_CALL_ERROR (cuInit (0 ));
129
+ CUresult result = cuInit (0 );
130
+ if (result != CUDA_SUCCESS) {
131
+ RTC_LOG (LS_ERROR) << " Failed to initialize CUDA." ;
132
+ return false ;
133
+ }
134
+
135
+ result = cuDeviceGetCount (&num_devices);
136
+ if (result != CUDA_SUCCESS) {
137
+ RTC_LOG (LS_ERROR) << " Failed to get CUDA device count." ;
138
+ return false ;
139
+ }
140
+
141
+ if (num_devices == 0 ) {
142
+ RTC_LOG (LS_ERROR) << " No CUDA devices found." ;
143
+ return false ;
144
+ }
145
+
88
146
CUCTX_CUDA_CALL_ERROR (cuDeviceGet (&cu_device, 0 ));
89
147
90
148
char device_name[80 ];
@@ -104,6 +162,22 @@ bool CudaContext::Initialize() {
104
162
return true ;
105
163
}
106
164
165
+ CUcontext CudaContext::GetContext () const {
166
+ RTC_DCHECK (cu_context_ != nullptr );
167
+ // Ensure the context is current
168
+ CUcontext current;
169
+ if (cuCtxGetCurrent (¤t) != CUDA_SUCCESS) {
170
+ throw ;
171
+ }
172
+ if (cu_context_ == current) {
173
+ return cu_context_;
174
+ }
175
+ if (cuCtxSetCurrent (cu_context_) != CUDA_SUCCESS) {
176
+ throw ;
177
+ }
178
+ return cu_context_;
179
+ }
180
+
107
181
void CudaContext::Shutdown () {
108
182
// Shutdown CUDA context
109
183
if (cu_context_) {
0 commit comments