@@ -363,3 +363,130 @@ func TestDeletionAndStatusChange(t *testing.T) {
363
363
t .Errorf ("Status changed" )
364
364
}
365
365
}
366
+
367
+ func TestRemoveIneligibleDomainsSharding (t * testing.T ) {
368
+ api , _ , mockHstspreload , mockPreloadlist := mockAPI (0 * time .Second )
369
+
370
+ testPreloadlist := preloadlist.PreloadList {Entries : []preloadlist.Entry {
371
+ {Name : "a.test" , Mode : preloadlist .ForceHTTPS , IncludeSubDomains : true , Policy : preloadlist .Bulk1Year },
372
+ {Name : "n.test" , Mode : preloadlist .ForceHTTPS , IncludeSubDomains : true , Policy : preloadlist .Bulk1Year },
373
+ {Name : "z.test" , Mode : preloadlist .ForceHTTPS , IncludeSubDomains : true , Policy : preloadlist .Bulk1Year },
374
+ }}
375
+ testEligibleResponses := map [string ]hstspreload.Issues {
376
+ "a.test" : issuesWithErrors ,
377
+ "n.test" : issuesWithErrors ,
378
+ "z.test" : issuesWithErrors ,
379
+ }
380
+ mockPreloadlist .list = testPreloadlist
381
+ mockHstspreload .eligibleResponses = testEligibleResponses
382
+
383
+ w := httptest .NewRecorder ()
384
+ w .Body = & bytes.Buffer {}
385
+
386
+ r , err := http .NewRequest ("GET" , "" , nil )
387
+ if err != nil {
388
+ t .Fatalf ("[%s] %s" , "NewRequest Failed" , err )
389
+ }
390
+
391
+ api .Update (w , r )
392
+
393
+ // These test cases are structured to be run in this specific order and
394
+ // each case depends on the behavior of the previous ones.
395
+ tests := []struct {
396
+ name string
397
+ query string
398
+ expectedCounts map [string ]int
399
+ }{
400
+ {
401
+ // Start by running RemoveIneligibleDomains with no query
402
+ // parameters - it should process every domain.
403
+ "no range specified" ,
404
+ "" ,
405
+ map [string ]int {
406
+ "a.test" : 1 ,
407
+ "n.test" : 1 ,
408
+ "z.test" : 1 ,
409
+ },
410
+ },
411
+ {
412
+ // Specifying an end of "n" (the [start, end) interval is half-open)
413
+ // should result in only a.test being processed. Every time a domain
414
+ // is processed, the number of scans in its IneligibleDomainState
415
+ // increases.
416
+ "query range only has end" ,
417
+ "end=n" ,
418
+ map [string ]int {
419
+ "a.test" : 2 ,
420
+ "n.test" : 1 ,
421
+ "z.test" : 1 ,
422
+ },
423
+ },
424
+ {
425
+ // With an interval of ["n","z"), only n.test should match.
426
+ "start and end" ,
427
+ "start=n&end=z" ,
428
+ map [string ]int {
429
+ "a.test" : 2 ,
430
+ "n.test" : 2 ,
431
+ "z.test" : 1 ,
432
+ },
433
+ },
434
+ {
435
+ // A start of "z" with no end should only match z.test from the
436
+ // test preload list.
437
+ "only start" ,
438
+ "start=z" ,
439
+ map [string ]int {
440
+ "a.test" : 2 ,
441
+ "n.test" : 2 ,
442
+ "z.test" : 2 ,
443
+ },
444
+ },
445
+ {
446
+ // A bad range (start after end) does nothing.
447
+ "bad range" ,
448
+ "start=b&end=a" ,
449
+ map [string ]int {
450
+ "a.test" : 2 ,
451
+ "n.test" : 2 ,
452
+ "z.test" : 2 ,
453
+ },
454
+ },
455
+ }
456
+
457
+ for _ , test := range tests {
458
+ // Make request with the specified query
459
+ r , err := http .NewRequest ("GET" , "" , nil )
460
+ if err != nil {
461
+ t .Fatalf ("[%s] %s" , "NewRequest Failed" , err )
462
+ }
463
+ r = toAppEngineHttpRequest (r )
464
+ r .URL .RawQuery = test .query
465
+ w := httptest .NewRecorder ()
466
+ w .Body = & bytes.Buffer {}
467
+ api .RemoveIneligibleDomains (w , r )
468
+
469
+ // Look at the IneligibleDomainStates created or updated by
470
+ // RemoveIneligibleDomains and check that the number of scans for
471
+ // each domain matches the expected count.
472
+ states , err := api .database .GetAllIneligibleDomainStates ()
473
+ if err != nil {
474
+ t .Fatalf ("Couldn't get the states of all domains in the database." )
475
+ }
476
+ seenNames := make (map [string ]bool )
477
+ for _ , state := range states {
478
+ expectedCount , found := test .expectedCounts [state .Name ]
479
+ if ! found {
480
+ t .Errorf ("[%s] found unexpected domain %q in IneligibleDomainStates list" , test .name , state .Name )
481
+ continue
482
+ }
483
+ if len (state .Scans ) != expectedCount {
484
+ t .Errorf ("[%s] Unexpected number of scans for domain %q: got %d, want %d" , test .name , state .Name , len (state .Scans ), expectedCount )
485
+ }
486
+ seenNames [state .Name ] = true
487
+ }
488
+ if len (seenNames ) != len (test .expectedCounts ) {
489
+ t .Errorf ("[%s] Wrong number of IneligibleDomainStates: got %d, want %d" , test .name , len (seenNames ), len (test .expectedCounts ))
490
+ }
491
+ }
492
+ }
0 commit comments