Skip to content

Commit 73e280a

Browse files
authored
fix: gpu resource device count calculation (#107)
1 parent 8cd4657 commit 73e280a

File tree

3 files changed

+76
-5
lines changed

3 files changed

+76
-5
lines changed

pkg/scheduler/api/resource_info/gpu_resource_requirment.go

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -40,14 +40,16 @@ func NewGpuResourceRequirement() *GpuResourceRequirement {
4040

4141
// NewGpuResourceRequirementWithGpus builds a GpuResourceRequirement from a
// GPU amount (gpus) and a GPU memory amount (gpuMemory).
//
// After this change the device count starts at 0 and is only raised when
// something is actually requested:
//   - gpus >= wholeGpuPortion: count becomes int64(gpus) and portion is
//     set to wholeGpuPortion.
//   - otherwise, gpus > 0 or gpuMemory > 0: count becomes fractionDefaultCount.
//   - otherwise (nothing requested): count stays 0.
func NewGpuResourceRequirementWithGpus(gpus float64, gpuMemory int64) *GpuResourceRequirement {
4242
gResource := &GpuResourceRequirement{
43-
count: fractionDefaultCount,
43+
count: 0,
4444
portion: gpus,
4545
gpuMemory: gpuMemory,
4646
migResources: make(map[v1.ResourceName]int64),
4747
}
4848
if gpus >= wholeGpuPortion {
4949
gResource.count = int64(gpus) // float-to-int conversion drops any fractional part of gpus
5050
gResource.portion = wholeGpuPortion
51+
} else if gpus > 0 || gpuMemory > 0 { // Fraction
52+
gResource.count = fractionDefaultCount
5153
}
5254
return gResource
5355
}

pkg/scheduler/api/resource_info/gpu_resource_requirment_test.go

Lines changed: 17 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -150,6 +150,23 @@ var _ = Describe("GpuResourceRequirement mechanism", func() {
150150
Expect(gpuResource1.LessEqual(gpuResource2)).To(BeTrue())
151151
})
152152
})
153+
154+
Context("GetNumOfGpuDevices", func() {
155+
It("should return 0 for 0 GPUs", func() {
156+
gpuResource := NewGpuResourceRequirementWithGpus(0, 0)
157+
Expect(gpuResource.GetNumOfGpuDevices()).To(Equal(int64(0)))
158+
})
159+
160+
It("should return 1 for 0.5 GPUs", func() {
161+
gpuResource := NewGpuResourceRequirementWithGpus(0.5, 0)
162+
Expect(gpuResource.GetNumOfGpuDevices()).To(Equal(int64(1)))
163+
})
164+
165+
It("should return 2 for 2 GPUs", func() {
166+
gpuResource := NewGpuResourceRequirementWithGpus(2, 0)
167+
Expect(gpuResource.GetNumOfGpuDevices()).To(Equal(int64(2)))
168+
})
169+
})
153170
})
154171

155172
func newGpuResourceRequirementWithValues(

pkg/scheduler/plugins/proportion/capacity_policy/max_allowed_check_test.go

Lines changed: 56 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -6,8 +6,14 @@ package capacity_policy
66
import (
77
. "github.com/onsi/ginkgo/v2"
88
. "github.com/onsi/gomega"
9+
v1 "k8s.io/api/core/v1"
10+
"k8s.io/apimachinery/pkg/api/resource"
911
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
1012

13+
"reflect"
14+
15+
"github.com/NVIDIA/KAI-scheduler/pkg/apis/scheduling/v2alpha2"
16+
"github.com/NVIDIA/KAI-scheduler/pkg/common/constants"
1117
commonconstants "github.com/NVIDIA/KAI-scheduler/pkg/common/constants"
1218
"github.com/NVIDIA/KAI-scheduler/pkg/scheduler/api/common_info"
1319
"github.com/NVIDIA/KAI-scheduler/pkg/scheduler/api/podgroup_info"
@@ -205,10 +211,11 @@ var _ = Describe("Max Allowed Policy Check", func() {
205211
Describe("resultsOverLimit", func() {
206212
Context("resultsOverLimit tests", func() {
207213
tests := map[string]struct {
208-
queues map[common_info.QueueID]*rs.QueueAttributes
209-
job *podgroup_info.PodGroupInfo
210-
requestedShare rs.ResourceQuantities
211-
expectedResult bool
214+
queues map[common_info.QueueID]*rs.QueueAttributes
215+
job *podgroup_info.PodGroupInfo
216+
requestedShare rs.ResourceQuantities
217+
expectedResult bool
218+
expectedDetails *v2alpha2.QuotaDetails // optional, only compare if set
212219
}{
213220
"single queue - unlimited queue": {
214221
queues: map[common_info.QueueID]*rs.QueueAttributes{
@@ -288,6 +295,39 @@ var _ = Describe("Max Allowed Policy Check", func() {
288295
},
289296
expectedResult: false,
290297
},
298+
"single queue - limited queue - results above limit 0 allocated": {
299+
queues: map[common_info.QueueID]*rs.QueueAttributes{
300+
"queue1": {
301+
UID: "queue1",
302+
Name: "queue1",
303+
ParentQueue: "",
304+
ChildQueues: nil,
305+
CreationTimestamp: metav1.Time{},
306+
QueueResourceShare: rs.QueueResourceShare{
307+
GPU: rs.ResourceShare{
308+
MaxAllowed: 3,
309+
Allocated: 0,
310+
},
311+
},
312+
},
313+
},
314+
job: &podgroup_info.PodGroupInfo{
315+
Name: "job-a",
316+
Namespace: "team-a",
317+
Queue: "queue1",
318+
},
319+
requestedShare: rs.ResourceQuantities{
320+
rs.GpuResource: 4,
321+
},
322+
expectedResult: false,
323+
expectedDetails: &v2alpha2.QuotaDetails{
324+
QueueAllocatedResources: v1.ResourceList{
325+
v1.ResourceCPU: *resource.NewMilliQuantity(0, resource.DecimalSI),
326+
v1.ResourceMemory: *resource.NewQuantity(0, resource.DecimalSI),
327+
constants.GpuResource: *resource.NewQuantity(0, resource.DecimalSI),
328+
},
329+
},
330+
},
291331
"multiple queues - limited queues - results below limit": {
292332
queues: map[common_info.QueueID]*rs.QueueAttributes{
293333
"top-queue": {
@@ -401,6 +441,18 @@ var _ = Describe("Max Allowed Policy Check", func() {
401441
capacityPolicy := New(testData.queues, true)
402442
result := capacityPolicy.resultsOverLimit(testData.requestedShare, testData.job)
403443
Expect(result.IsSchedulable).To(Equal(testData.expectedResult))
444+
if testData.expectedDetails != nil {
445+
expectedValues := reflect.ValueOf(*testData.expectedDetails)
446+
resultValues := reflect.ValueOf(*result.Details.QueueDetails)
447+
for i := 0; i < expectedValues.NumField(); i++ {
448+
xField := expectedValues.Field(i).Interface()
449+
yField := resultValues.Field(i).Interface()
450+
zero := reflect.Zero(expectedValues.Field(i).Type()).Interface()
451+
if !reflect.DeepEqual(xField, zero) {
452+
Expect(xField).To(Equal(yField))
453+
}
454+
}
455+
}
404456
})
405457
}
406458

0 commit comments

Comments
 (0)