Skip to content

Commit 83825dd

Browse files
committed
Auto merge of #143784 - scottmcm:enums-again-new-ex2, r=dianqk
Simplify discriminant codegen for niche-encoded variants which don't wrap across an integer boundary

Inspired by #139729, this attempts to be a much-simpler and more-localized change while still making a difference. (Specifically, this does not try to solve the problem with select-sinking, leaving that to be fixed by llvm/llvm-project#134024 -- once it gets released -- instead of in rustc's codegen.)

What this *does* improve is checking for the variant in a 3+ variant enum when that variant is the type providing the niche. Something like `if let Foo::WithBool(_) = ...` previously compiled to `ugt(add(x, -2), 2)`, which is non-trivial to think about because it depends on unsigned wrapping to shift the 0/1 up above 2. With this PR it compiles to just `ult(x, 2)`, which is probably what you'd have written yourself if you were doing it by hand to look for "is this byte a bool?".

That's done by leaving most of the codegen alone, but adding a couple of new special cases to the `is_niche` check. The default looks at the relative discriminant, but in the common cases where there's no wraparound involved, we can just check the original value, rather than the offset one.

The first commit just adds some tests, so the best way to see the effect of this change is to look at the second commit and how it updates the test expectations.
2 parents 1079c5e + 4fa23d9 commit 83825dd

File tree

4 files changed

+669
-40
lines changed

4 files changed

+669
-40
lines changed

compiler/rustc_abi/src/lib.rs

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ use std::fmt;
4343
#[cfg(feature = "nightly")]
4444
use std::iter::Step;
4545
use std::num::{NonZeroUsize, ParseIntError};
46-
use std::ops::{Add, AddAssign, Deref, Mul, RangeInclusive, Sub};
46+
use std::ops::{Add, AddAssign, Deref, Mul, RangeFull, RangeInclusive, Sub};
4747
use std::str::FromStr;
4848

4949
use bitflags::bitflags;
@@ -1391,12 +1391,45 @@ impl WrappingRange {
13911391
}
13921392

13931393
/// Returns `true` if the range covers every value representable in `size`.
1394+
///
1395+
/// Note that this is *not* the same as `self == WrappingRange::full(size)`.
1396+
/// Niche calculations can produce full ranges which are not the canonical one;
1397+
/// for example `Option<NonZero<u16>>` gets `valid_range: (..=0) | (1..)`.
13941398
#[inline]
13951399
fn is_full_for(&self, size: Size) -> bool {
13961400
let max_value = size.unsigned_int_max();
13971401
debug_assert!(self.start <= max_value && self.end <= max_value);
13981402
self.start == (self.end.wrapping_add(1) & max_value)
13991403
}
1404+
1405+
/// Checks whether this range is considered non-wrapping when the values are
1406+
/// interpreted as *unsigned* numbers of width `size`.
1407+
///
1408+
/// Returns `Ok(true)` if there's no wrap-around, `Ok(false)` if there is,
1409+
/// and `Err(..)` if the range is full, since whether it wraps then depends on how you think about it.
1410+
#[inline]
1411+
pub fn no_unsigned_wraparound(&self, size: Size) -> Result<bool, RangeFull> {
1412+
if self.is_full_for(size) { Err(..) } else { Ok(self.start <= self.end) }
1413+
}
1414+
1415+
/// Checks whether this range is considered non-wrapping when the values are
1416+
/// interpreted as *signed* numbers of width `size`.
1417+
///
1418+
/// This is heavily dependent on the `size`, as `100..=200` does wrap when
1419+
/// interpreted as `i8`, but doesn't when interpreted as `i16`.
1420+
///
1421+
/// Returns `Ok(true)` if there's no wrap-around, `Ok(false)` if there is,
1422+
/// and `Err(..)` if the range is full, since whether it wraps then depends on how you think about it.
1423+
#[inline]
1424+
pub fn no_signed_wraparound(&self, size: Size) -> Result<bool, RangeFull> {
1425+
if self.is_full_for(size) {
1426+
Err(..)
1427+
} else {
1428+
let start: i128 = size.sign_extend(self.start);
1429+
let end: i128 = size.sign_extend(self.end);
1430+
Ok(start <= end)
1431+
}
1432+
}
14001433
}
14011434

14021435
impl fmt::Debug for WrappingRange {

compiler/rustc_codegen_ssa/src/mir/operand.rs

Lines changed: 96 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -486,6 +486,7 @@ impl<'a, 'tcx, V: CodegenObject> OperandRef<'tcx, V> {
486486
// value and the variant index match, since that's all `Niche` can encode.
487487

488488
let relative_max = niche_variants.end().as_u32() - niche_variants.start().as_u32();
489+
let niche_start_const = bx.cx().const_uint_big(tag_llty, niche_start);
489490

490491
// We have a subrange `niche_start..=niche_end` inside `range`.
491492
// If the value of the tag is inside this subrange, it's a
@@ -511,35 +512,88 @@ impl<'a, 'tcx, V: CodegenObject> OperandRef<'tcx, V> {
511512
// } else {
512513
// untagged_variant
513514
// }
514-
let niche_start = bx.cx().const_uint_big(tag_llty, niche_start);
515-
let is_niche = bx.icmp(IntPredicate::IntEQ, tag, niche_start);
515+
let is_niche = bx.icmp(IntPredicate::IntEQ, tag, niche_start_const);
516516
let tagged_discr =
517517
bx.cx().const_uint(cast_to, niche_variants.start().as_u32() as u64);
518518
(is_niche, tagged_discr, 0)
519519
} else {
520-
// The special cases don't apply, so we'll have to go with
521-
// the general algorithm.
522-
let relative_discr = bx.sub(tag, bx.cx().const_uint_big(tag_llty, niche_start));
520+
// With multiple niched variants we'll have to actually compute
521+
// the variant index from the stored tag.
522+
//
523+
// However, there's still one small optimization we can often do for
524+
// determining *whether* a tag value is a natural value or a niched
525+
// variant. The general algorithm involves a subtraction that often
526+
// wraps in practice, making it tricky to analyse. However, in cases
527+
// where there are few enough possible values of the tag that it doesn't
528+
// need to wrap around, we can instead just look for the contiguous
529+
// tag values on the end of the range with a single comparison.
530+
//
531+
// For example, take the type `enum Demo { A, B, Untagged(bool) }`.
532+
// The `bool` is {0, 1}, and the two other variants are given the
533+
// tags {2, 3} respectively. That means the `tag_range` is
534+
// `[0, 3]`, which doesn't wrap as unsigned (nor as signed), so
535+
// we can test for the niched variants with just `>= 2`.
536+
//
537+
// That means we're looking either for the niche values *above*
538+
// the natural values of the untagged variant:
539+
//
540+
// niche_start niche_end
541+
// | |
542+
// v v
543+
// MIN -------------+---------------------------+---------- MAX
544+
// ^ | is niche |
545+
// | +---------------------------+
546+
// | |
547+
// tag_range.start tag_range.end
548+
//
549+
// Or *below* the natural values:
550+
//
551+
// niche_start niche_end
552+
// | |
553+
// v v
554+
// MIN ----+-----------------------+---------------------- MAX
555+
// | is niche | ^
556+
// +-----------------------+ |
557+
// | |
558+
// tag_range.start tag_range.end
559+
//
560+
// With those two options and having the flexibility to choose
561+
// between a signed or unsigned comparison on the tag, that
562+
// covers most realistic scenarios. The tests have a (contrived)
563+
// example of a 1-byte enum with over 128 niched variants which
564+
// wraps both as signed as unsigned, though, and for something
565+
// like that we're stuck with the general algorithm.
566+
567+
let tag_range = tag_scalar.valid_range(&dl);
568+
let tag_size = tag_scalar.size(&dl);
569+
let niche_end = u128::from(relative_max).wrapping_add(niche_start);
570+
let niche_end = tag_size.truncate(niche_end);
571+
572+
let relative_discr = bx.sub(tag, niche_start_const);
523573
let cast_tag = bx.intcast(relative_discr, cast_to, false);
524-
let is_niche = bx.icmp(
525-
IntPredicate::IntULE,
526-
relative_discr,
527-
bx.cx().const_uint(tag_llty, relative_max as u64),
528-
);
529-
530-
// Thanks to parameter attributes and load metadata, LLVM already knows
531-
// the general valid range of the tag. It's possible, though, for there
532-
// to be an impossible value *in the middle*, which those ranges don't
533-
// communicate, so it's worth an `assume` to let the optimizer know.
534-
if niche_variants.contains(&untagged_variant)
535-
&& bx.cx().sess().opts.optimize != OptLevel::No
536-
{
537-
let impossible =
538-
u64::from(untagged_variant.as_u32() - niche_variants.start().as_u32());
539-
let impossible = bx.cx().const_uint(tag_llty, impossible);
540-
let ne = bx.icmp(IntPredicate::IntNE, relative_discr, impossible);
541-
bx.assume(ne);
542-
}
574+
let is_niche = if tag_range.no_unsigned_wraparound(tag_size) == Ok(true) {
575+
if niche_start == tag_range.start {
576+
let niche_end_const = bx.cx().const_uint_big(tag_llty, niche_end);
577+
bx.icmp(IntPredicate::IntULE, tag, niche_end_const)
578+
} else {
579+
assert_eq!(niche_end, tag_range.end);
580+
bx.icmp(IntPredicate::IntUGE, tag, niche_start_const)
581+
}
582+
} else if tag_range.no_signed_wraparound(tag_size) == Ok(true) {
583+
if niche_start == tag_range.start {
584+
let niche_end_const = bx.cx().const_uint_big(tag_llty, niche_end);
585+
bx.icmp(IntPredicate::IntSLE, tag, niche_end_const)
586+
} else {
587+
assert_eq!(niche_end, tag_range.end);
588+
bx.icmp(IntPredicate::IntSGE, tag, niche_start_const)
589+
}
590+
} else {
591+
bx.icmp(
592+
IntPredicate::IntULE,
593+
relative_discr,
594+
bx.cx().const_uint(tag_llty, relative_max as u64),
595+
)
596+
};
543597

544598
(is_niche, cast_tag, niche_variants.start().as_u32() as u128)
545599
};
@@ -550,11 +604,24 @@ impl<'a, 'tcx, V: CodegenObject> OperandRef<'tcx, V> {
550604
bx.add(tagged_discr, bx.cx().const_uint_big(cast_to, delta))
551605
};
552606

553-
let discr = bx.select(
554-
is_niche,
555-
tagged_discr,
556-
bx.cx().const_uint(cast_to, untagged_variant.as_u32() as u64),
557-
);
607+
let untagged_variant_const =
608+
bx.cx().const_uint(cast_to, u64::from(untagged_variant.as_u32()));
609+
610+
// Thanks to parameter attributes and load metadata, LLVM already knows
611+
// the general valid range of the tag. It's possible, though, for there
612+
// to be an impossible value *in the middle*, which those ranges don't
613+
// communicate, so it's worth an `assume` to let the optimizer know.
614+
// Most importantly, this means when optimizing a variant test like
615+
// `SELECT(is_niche, complex, CONST) == CONST` it's ok to simplify that
616+
// to `!is_niche` because the `complex` part can't possibly match.
617+
if niche_variants.contains(&untagged_variant)
618+
&& bx.cx().sess().opts.optimize != OptLevel::No
619+
{
620+
let ne = bx.icmp(IntPredicate::IntNE, tagged_discr, untagged_variant_const);
621+
bx.assume(ne);
622+
}
623+
624+
let discr = bx.select(is_niche, tagged_discr, untagged_variant_const);
558625

559626
// In principle we could insert assumes on the possible range of `discr`, but
560627
// currently in LLVM this isn't worth it because the original `tag` will

0 commit comments

Comments
 (0)