@@ -20,6 +20,7 @@ import (
2020 "bytes"
2121 "context"
2222 "fmt"
23+ "math/rand"
2324 "regexp"
2425 "slices"
2526 "sort"
@@ -32,9 +33,11 @@ import (
 	v4Converged "github.com/nutanix-cloud-native/prism-go-client/converged/v4"
 	prismclientv3 "github.com/nutanix-cloud-native/prism-go-client/v3"
 	prismclientv4 "github.com/nutanix-cloud-native/prism-go-client/v4"
-	clustermgmtconfig "github.com/nutanix/ntnx-api-golang-clients/clustermgmt-go-client/v4/models/clustermgmt/v4/config"
+	clusterModels "github.com/nutanix/ntnx-api-golang-clients/clustermgmt-go-client/v4/models/clustermgmt/v4/config"
 	subnetModels "github.com/nutanix/ntnx-api-golang-clients/networking-go-client/v4/models/networking/v4/config"
 	prismModels "github.com/nutanix/ntnx-api-golang-clients/prism-go-client/v4/models/prism/v4/config"
+	vmmconfig "github.com/nutanix/ntnx-api-golang-clients/vmm-go-client/v4/models/vmm/v4/ahv/config"
+	imageModels "github.com/nutanix/ntnx-api-golang-clients/vmm-go-client/v4/models/vmm/v4/content"
 	prismconfig "github.com/nutanix/ntnx-api-golang-clients/volumes-go-client/v4/models/prism/v4/config"
 	volumesconfig "github.com/nutanix/ntnx-api-golang-clients/volumes-go-client/v4/models/volumes/v4/config"
 	"k8s.io/apimachinery/pkg/api/resource"
@@ -46,16 +49,13 @@ import (
 
 	infrav1 "github.com/nutanix-cloud-native/cluster-api-provider-nutanix/api/v1beta1"
 	nutanixclient "github.com/nutanix-cloud-native/cluster-api-provider-nutanix/pkg/client"
-	imageModels "github.com/nutanix/ntnx-api-golang-clients/vmm-go-client/v4/models/vmm/v4/content"
 )
 
 const (
 	providerIdPrefix = "nutanix://"
 
 	subnetTypeOverlay = "OVERLAY"
 
-	gpuUnused = "UNUSED"
-
 	detachVGRequeueAfter = 30 * time.Second
 
 	ImageStateDeletePending = "DELETE_PENDING"
@@ -226,7 +226,7 @@ func GetPEUUID(ctx context.Context, client *v4Converged.Client, peName, peUUID *
 		return "", err
 	}
 	// Validate filtered PEs
-	foundPEs := make([]clustermgmtconfig.Cluster, 0)
+	foundPEs := make([]clusterModels.Cluster, 0)
 	for _, s := range responsePEs {
 		if strings.EqualFold(*s.Name, *peName) && hasPEClusterServiceEnabled(&s) {
 			foundPEs = append(foundPEs, s)
@@ -813,23 +813,23 @@ func GetProjectUUID(ctx context.Context, client *prismclientv3.Client, projectNa
 	return foundProjectUUID, nil
 }
 
-func hasPEClusterServiceEnabled(peCluster *clustermgmtconfig.Cluster) bool {
+func hasPEClusterServiceEnabled(peCluster *clusterModels.Cluster) bool {
 	if peCluster.Config == nil ||
 		peCluster.Config.ClusterFunction == nil {
 		return false
 	}
 	serviceList := peCluster.Config.ClusterFunction
 	for _, s := range serviceList {
-		if strings.ToUpper(string(s.GetName())) == clustermgmtconfig.CLUSTERFUNCTIONREF_AOS.GetName() {
+		if strings.ToUpper(string(s.GetName())) == clusterModels.CLUSTERFUNCTIONREF_AOS.GetName() {
 			return true
 		}
 	}
 	return false
 }
 
 // GetGPUList returns a list of GPU device IDs for the given list of GPUs
-func GetGPUList(ctx context.Context, client *prismclientv3.Client, gpus []infrav1.NutanixGPU, peUUID string) ([]*prismclientv3.VMGpu, error) {
-	resultGPUs := make([]*prismclientv3.VMGpu, 0)
+func GetGPUList(ctx context.Context, client *v4Converged.Client, gpus []infrav1.NutanixGPU, peUUID string) ([]*vmmconfig.Gpu, error) {
+	resultGPUs := make([]*vmmconfig.Gpu, 0)
 	for _, gpu := range gpus {
 		foundGPU, err := GetGPU(ctx, client, peUUID, gpu)
 		if err != nil {
@@ -841,63 +841,128 @@ func GetGPUList(ctx context.Context, client *prismclientv3.Client, gpus []infrav
 }
 
 // GetGPUDeviceID returns the device ID of a GPU with the given name
-func GetGPU(ctx context.Context, client *prismclientv3.Client, peUUID string, gpu infrav1.NutanixGPU) (*prismclientv3.VMGpu, error) {
+func GetGPU(ctx context.Context, client *v4Converged.Client, peUUID string, gpu infrav1.NutanixGPU) (*vmmconfig.Gpu, error) {
 	gpuDeviceID := gpu.DeviceID
 	gpuDeviceName := gpu.Name
 	if gpuDeviceID == nil && gpuDeviceName == nil {
 		return nil, fmt.Errorf("gpu name or gpu device ID must be passed in order to retrieve the GPU")
 	}
-	allGPUs, err := GetGPUsForPE(ctx, client, peUUID)
+
+	allUnusedGPUs, err := GetGPUsForPE(ctx, client, peUUID, gpu)
 	if err != nil {
 		return nil, err
 	}
-	if len(allGPUs) == 0 {
+	if len(allUnusedGPUs) == 0 {
 		return nil, fmt.Errorf("no available GPUs found in Prism Element cluster with UUID %s", peUUID)
 	}
-	for _, peGPU := range allGPUs {
-		if peGPU.Status != gpuUnused {
+
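+	// Pick one of the matching unused GPUs at random.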
+	randomIndex := rand.Intn(len(allUnusedGPUs))
+	return allUnusedGPUs[randomIndex], nil
+}
+
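+// GetGPUsForPE returns the unused physical and virtual GPUs in the Prism Element cluster
+// identified by peUUID that match the requested GPU name or device ID.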
+func GetGPUsForPE(ctx context.Context, client *v4Converged.Client, peUUID string, gpu infrav1.NutanixGPU) ([]*vmmconfig.Gpu, error) {
+	var filter string
+	var gpus []*vmmconfig.Gpu
+
+	if gpu.DeviceID != nil {
+		filter = fmt.Sprintf("physicalGpuConfig/deviceId eq %d", *gpu.DeviceID)
+	} else if gpu.Name != nil {
+		filter = fmt.Sprintf("physicalGpuConfig/deviceName eq '%s'", *gpu.Name)
+	}
+
+	physicalGPUs, err := client.Clusters.ListClusterPhysicalGPUs(ctx, peUUID, converged.WithFilter(filter))
+	if err != nil {
+		return nil, err
+	}
+	for _, physicalGPU := range physicalGPUs {
+		if physicalGPU.PhysicalGpuConfig.IsInUse != nil && *physicalGPU.PhysicalGpuConfig.IsInUse {
 			continue
 		}
-		if (gpuDeviceID != nil && *peGPU.DeviceID == *gpuDeviceID) || (gpuDeviceName != nil && *gpuDeviceName == peGPU.Name) {
-			return &prismclientv3.VMGpu{
-				DeviceID: peGPU.DeviceID,
-				Mode:     &peGPU.Mode,
-				Vendor:   &peGPU.Vendor,
-			}, err
+
+		vmGpu := vmmconfig.NewGpu()
+		vmGpu.Name = physicalGPU.PhysicalGpuConfig.DeviceName
+		vmGpu.DeviceId = ptr.To(int(*physicalGPU.PhysicalGpuConfig.DeviceId))
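+		// Physical GPUs are attached in passthrough mode: default to compute, switch to graphics when the device reports a graphics type.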
+		vmGpu.Mode = vmmconfig.GPUMODE_PASSTHROUGH_COMPUTE.Ref()
+		if physicalGPU.PhysicalGpuConfig.Type != nil && *physicalGPU.PhysicalGpuConfig.Type == clusterModels.GPUTYPE_PASSTHROUGH_GRAPHICS {
+			vmGpu.Mode = vmmconfig.GPUMODE_PASSTHROUGH_GRAPHICS.Ref()
 		}
+		vmGpu.Vendor = gpuVendorStringToGpuVendor(*physicalGPU.PhysicalGpuConfig.VendorName)
+		gpus = append(gpus, vmGpu)
 	}
-	return nil, fmt.Errorf("no available GPU found in Prism Element that matches required GPU inputs")
-}
 
-func GetGPUsForPE(ctx context.Context, client *prismclientv3.Client, peUUID string) ([]*prismclientv3.GPU, error) {
-	gpus := make([]*prismclientv3.GPU, 0)
-	// We use ListHost, because it returns all hosts, since the endpoint does not support pagination,
-	// and ListAllHost incorrectly handles pagination. https://jira.nutanix.com/browse/NCN-110045
-	hosts, err := client.V3.ListHost(ctx, &prismclientv3.DSMetadata{})
-	if err != nil {
-		return gpus, err
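+	// Repeat the lookup for virtual GPU (vGPU) profiles that match the same identifier.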
+	if gpu.Name != nil {
+		filter = fmt.Sprintf("virtualGpuConfig/deviceName eq '%s'", *gpu.Name)
+	} else if gpu.DeviceID != nil {
+		filter = fmt.Sprintf("virtualGpuConfig/deviceId eq %d", *gpu.DeviceID)
 	}
 
-	for _, host := range hosts.Entities {
-		if host == nil ||
-			host.Status == nil ||
-			host.Status.ClusterReference == nil ||
-			host.Status.Resources == nil ||
-			len(host.Status.Resources.GPUList) == 0 ||
-			host.Status.ClusterReference.UUID != peUUID {
+	virtualGPUs, err := client.Clusters.ListClusterVirtualGPUs(ctx, peUUID, converged.WithFilter(filter))
+	if err != nil {
+		return nil, err
+	}
+	for _, virtualGPU := range virtualGPUs {
+		if virtualGPU.VirtualGpuConfig.IsInUse != nil && *virtualGPU.VirtualGpuConfig.IsInUse {
 			continue
 		}
 
-		for _, peGpu := range host.Status.Resources.GPUList {
-			if peGpu == nil {
-				continue
-			}
-			gpus = append(gpus, peGpu)
-		}
+		vmGpu := vmmconfig.NewGpu()
+		vmGpu.Name = virtualGPU.VirtualGpuConfig.DeviceName
+		vmGpu.DeviceId = ptr.To(int(*virtualGPU.VirtualGpuConfig.DeviceId))
+		vmGpu.Mode = vmmconfig.GPUMODE_VIRTUAL.Ref()
+		vmGpu.Vendor = gpuVendorStringToGpuVendor(*virtualGPU.VirtualGpuConfig.VendorName)
+		gpus = append(gpus, vmGpu)
 	}
 	return gpus, nil
 }
 
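+// gpuVendorStringToGpuVendor maps the vendor names reported by the clustermgmt API
+// (for example "kNvidia") to the corresponding vmm GpuVendor value.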
+func gpuVendorStringToGpuVendor(vendor string) *vmmconfig.GpuVendor {
+	switch vendor {
+	case "kNvidia":
+		return vmmconfig.GPUVENDOR_NVIDIA.Ref()
+	case "kIntel":
+		return vmmconfig.GPUVENDOR_INTEL.Ref()
+	case "kAmd":
+		return vmmconfig.GPUVENDOR_AMD.Ref()
+	default:
+		return vmmconfig.GPUVENDOR_UNKNOWN.Ref()
+	}
+}
+
+// TODO: delete when the VM code path is migrated to the v4Converged client.
+// v4GpuToV3Gpu converts a v4 Gpu to a v3 VMGpu.
+func v4GpuToV3Gpu(gpu *vmmconfig.Gpu) *prismclientv3.VMGpu {
+	var mode string
+	var vendor string
+
+	switch *gpu.Mode {
+	case vmmconfig.GPUMODE_PASSTHROUGH_COMPUTE:
+		mode = "PASSTHROUGH_COMPUTE"
+	case vmmconfig.GPUMODE_PASSTHROUGH_GRAPHICS:
+		mode = "PASSTHROUGH_GRAPHICS"
+	case vmmconfig.GPUMODE_VIRTUAL:
+		mode = "VIRTUAL"
+	default:
+		mode = "$UNKNOWN"
+	}
+
+	switch *gpu.Vendor {
+	case vmmconfig.GPUVENDOR_NVIDIA:
+		vendor = "NVIDIA"
+	case vmmconfig.GPUVENDOR_INTEL:
+		vendor = "INTEL"
+	case vmmconfig.GPUVENDOR_AMD:
+		vendor = "AMD"
+	default:
+		vendor = "UNKNOWN"
+	}
+
+	return &prismclientv3.VMGpu{
+		DeviceID: ptr.To(int64(*gpu.DeviceId)),
+		Mode:     ptr.To(mode),
+		Vendor:   ptr.To(vendor),
+	}
+}
+
 // GetLegacyFailureDomainFromNutanixCluster gets the failure domain with a given name from a NutanixCluster object.
 func GetLegacyFailureDomainFromNutanixCluster(failureDomainName string, nutanixCluster *infrav1.NutanixCluster) *infrav1.NutanixFailureDomainConfig { //nolint:staticcheck // suppress complaining on Deprecated type
 	for _, fd := range nutanixCluster.Spec.FailureDomains { //nolint:staticcheck // suppress complaining on Deprecated field
@@ -908,7 +973,7 @@ func GetLegacyFailureDomainFromNutanixCluster(failureDomainName string, nutanixC
 	return nil
 }
 
-func GetStorageContainerInCluster(ctx context.Context, client *v4Converged.Client, storageContainerIdentifier, clusterIdentifier infrav1.NutanixResourceIdentifier) (*clustermgmtconfig.StorageContainer, error) {
+func GetStorageContainerInCluster(ctx context.Context, client *v4Converged.Client, storageContainerIdentifier, clusterIdentifier infrav1.NutanixResourceIdentifier) (*clusterModels.StorageContainer, error) {
 	var filter, identifier string
 	switch {
 	case storageContainerIdentifier.IsUUID():
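
For context, a minimal sketch (not part of this change) of how a caller in the same package might combine the new lookup helpers while the VM-create path still speaks the v3 API; the helper name resolveVMGpus is illustrative only:

func resolveVMGpus(ctx context.Context, client *v4Converged.Client, gpus []infrav1.NutanixGPU, peUUID string) ([]*prismclientv3.VMGpu, error) {
	// Look up unused GPUs through the v4 converged client.
	found, err := GetGPUList(ctx, client, gpus, peUUID)
	if err != nil {
		return nil, err
	}
	// Convert each result to the v3 VMGpu shape expected by the existing VM request.
	vmGpus := make([]*prismclientv3.VMGpu, 0, len(found))
	for _, gpu := range found {
		vmGpus = append(vmGpus, v4GpuToV3Gpu(gpu))
	}
	return vmGpus, nil
}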