diff --git a/include/caffe/util/cudnn.hpp b/include/caffe/util/cudnn.hpp index a7d8dbba..cd3f93f6 100644 --- a/include/caffe/util/cudnn.hpp +++ b/include/caffe/util/cudnn.hpp @@ -41,6 +41,16 @@ inline const char* cudnnGetErrorString(cudnnStatus_t status) { return "CUDNN_STATUS_NOT_SUPPORTED"; case CUDNN_STATUS_LICENSE_ERROR: return "CUDNN_STATUS_LICENSE_ERROR"; +#if CUDNN_VERSION_MIN(6, 0, 0) + case CUDNN_STATUS_RUNTIME_PREREQUISITE_MISSING: + return "CUDNN_STATUS_RUNTIME_PREREQUISITE_MISSING"; +#endif +#if CUDNN_VERSION_MIN(7, 0, 0) + case CUDNN_STATUS_RUNTIME_IN_PROGRESS: + return "CUDNN_STATUS_RUNTIME_IN_PROGRESS"; + case CUDNN_STATUS_RUNTIME_FP_OVERFLOW: + return "CUDNN_STATUS_RUNTIME_FP_OVERFLOW"; +#endif } return "Unknown cudnn status"; } @@ -109,8 +119,14 @@ template inline void setConvolutionDesc(cudnnConvolutionDescriptor_t* conv, cudnnTensorDescriptor_t bottom, cudnnFilterDescriptor_t filter, int pad_h, int pad_w, int stride_h, int stride_w) { +#if CUDNN_VERSION_MIN(6, 0, 0) CUDNN_CHECK(cudnnSetConvolution2dDescriptor(*conv, + pad_h, pad_w, stride_h, stride_w, 1, 1, CUDNN_CROSS_CORRELATION, + dataType::type)); +#else + CUDNN_CHECK(cudnnSetConvolution2dDescriptor(*conv, pad_h, pad_w, stride_h, stride_w, 1, 1, CUDNN_CROSS_CORRELATION)); +#endif } template