[CIR] Add special type and new operations for vptrs #1745

Open
wants to merge 6 commits into
base: main
105 changes: 91 additions & 14 deletions clang/include/clang/CIR/Dialect/IR/CIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -2595,7 +2595,7 @@ def CIR_GetGlobalOp : CIR_Op<"get_global", [
// VTableAddrPointOp
//===----------------------------------------------------------------------===//

def CIR_VTableAddrPointOp : CIR_Op<"vtable.address_point",[
def CIR_VTableAddrPointOp : CIR_Op<"vtable.address_point", [
Pure, DeclareOpInterfaceMethods<SymbolUserOpInterface>
]> {
let summary = "Get the vtable (global variable) address point";
Expand All @@ -2604,39 +2604,116 @@ def CIR_VTableAddrPointOp : CIR_Op<"vtable.address_point",[
(address point) of a C++ virtual table. An object's internal `__vptr`
is initialized with the value returned by this operation.

`address_point.index` (vtable index) provides the appropriate vtable within the vtable group
(as specified by Itanium ABI), and `address_point.offset` (address point index) the actual address
point within that vtable.
`address_point.index` (vtable index) provides the appropriate vtable within
the vtable group (as specified by Itanium ABI), and `address_point.offset`
(address point index) the actual address point within that vtable.

The return type is always a `!cir.ptr<!cir.ptr<() -> i32>>`.
The return type is always `!cir.vptr`.

Example:
```mlir
cir.global linkonce_odr @_ZTV1B = ...
...
%3 = cir.vtable.address_point(@_ZTV1B, address_point = <index = 0, offset = 2>) : !cir.ptr<!cir.ptr<() -> i32>>
%3 = cir.vtable.address_point(@_ZTV1B,
address_point = <index = 0, offset = 2>) : !cir.vptr
```
}];

let arguments = (ins
OptionalAttr<FlatSymbolRefAttr>:$name,
Optional<CIR_AnyType>:$sym_addr,
FlatSymbolRefAttr:$name,
CIR_AddressPointAttr:$address_point
);

let results = (outs Res<CIR_PointerType, "", []>:$addr);
let results = (outs Res<CIR_VPtrType, "", []>:$addr);

let assemblyFormat = [{
`(`
($name^)?
($sym_addr^ `:` type($sym_addr))?
`,`
`address_point` `=` $address_point
$name `,` `address_point` `=` $address_point
`)`
`:` qualified(type($addr)) attr-dict
}];
}

let hasVerifier = 1;
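
To make the `address_point = <index = ..., offset = ...>` attribute more concrete, here is a minimal standalone sketch in plain C++. It is not ClangIR code: `VTable`, `VTableGroup`, `AddressPoint` and `resolveAddressPoint` are hypothetical names, and the assumption that each vtable starts with an offset-to-top entry followed by an RTTI entry reflects the common Itanium layout for a class without virtual bases.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical model: a vtable is a flat list of word-sized entries, and a
// vtable group is the list of vtables emitted together for one class.
using VTable = std::vector<std::intptr_t>;
using VTableGroup = std::vector<VTable>;

// Mirrors the two fields of the address_point attribute.
struct AddressPoint {
  std::size_t index;   // which vtable inside the group
  std::size_t offset;  // which entry inside that vtable
};

// Returns the address that an object's vptr would be initialized with.
inline const std::intptr_t *resolveAddressPoint(const VTableGroup &group,
                                                AddressPoint ap) {
  assert(ap.index < group.size() && "vtable index out of range");
  const VTable &vtable = group[ap.index];
  assert(ap.offset < vtable.size() && "address point out of range");
  return vtable.data() + ap.offset;
}
```

With `<index = 0, offset = 2>`, as in the op's example, the returned address skips the two leading entries of the first vtable, so the entries at negative offsets from the address point remain reachable by pointer arithmetic.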
//===----------------------------------------------------------------------===//
// VTableGetVPtr
//===----------------------------------------------------------------------===//

def CIR_VTableGetVPtrOp : CIR_Op<"vtable.get_vptr", [Pure]> {
let summary = "Get the address of the vtable pointer for an object";
let description = [{
The `vtable.get_vptr` operation retrieves the address of the vptr for a
C++ object. This operation requires that the object pointer points to
the start of a complete object. (TODO: Describe how we get that).
The vptr will always be at offset zero in the object, but this operation
is more explicit about what is being retrieved than a direct bitcast.

The return type is always `!cir.ptr<!cir.vptr>`.

Example:
```mlir
%2 = cir.load %0 : !cir.ptr<!cir.ptr<!rec_C>>, !cir.ptr<!rec_C>
%3 = cir.vtable.get_vptr %2 : !cir.ptr<!rec_C> -> !cir.ptr<!cir.vptr>
```
}];

let arguments = (ins
Arg<CIR_PointerType, "the vptr address", [MemRead]>:$src
);

Review comment from a collaborator, suggested change:

let results = (outs CIR_PtrToVPtr:$result);

let assemblyFormat = [{
$src `:` qualified(type($src)) `->` qualified(type($result)) attr-dict
}];

}

//===----------------------------------------------------------------------===//
// VTableGetVirtualFnAddrOp
//===----------------------------------------------------------------------===//

def CIR_VTableGetVirtualFnAddrOp : CIR_Op<"vtable.get_virtual_fn_addr", [
Pure
]> {
let summary = "Get the address of a virtual function pointer";
let description = [{
The `vtable.get_virtual_fn_addr` operation retrieves the address of a
virtual function pointer from an object's vtable (__vptr).
This is an abstraction to perform the basic pointer arithmetic to get
the address of the virtual function pointer, which can then be loaded and
called.
Comment on lines +2680 to +2684
@xlauko (Collaborator), Jul 31, 2025:
Maybe reference relation to !cir.vptr here.


The `vptr` operand must be a `!cir.vptr` value, typically loaded from the
`!cir.ptr<!cir.vptr>` address returned by a previous
`cir.vtable.get_vptr` operation. The `index` operand is the index of the
virtual function in the vtable.

The return type is a pointer-to-pointer to the function type.

Example:
```mlir
%2 = cir.load %0 : !cir.ptr<!cir.ptr<!rec_C>>, !cir.ptr<!rec_C>
%3 = cir.vtable.get_vptr %2 : !cir.ptr<!rec_C> -> !cir.ptr<!cir.vptr>
%4 = cir.load %3 : !cir.ptr<!cir.vptr>, !cir.vptr
%5 = cir.vtable.get_virtual_fn_addr %4[2] : !cir.vptr
-> !cir.ptr<!cir.ptr<!cir.func<(!cir.ptr<!rec_C>) -> !s32i>>>
%6 = cir.load align(8) %5 : !cir.ptr<!cir.ptr<!cir.func<(!cir.ptr<!rec_C>)
-> !s32i>>>,
!cir.ptr<!cir.func<(!cir.ptr<!rec_C>) -> !s32i>>
%7 = cir.call %6(%2) : (!cir.ptr<!cir.func<(!cir.ptr<!rec_C>) -> !s32i>>,
!cir.ptr<!rec_C>) -> !s32i
```
}];

let arguments = (ins
Arg<CIR_VPtrType, "vptr", [MemRead]>:$vptr,
I64Attr:$index);

let results = (outs CIR_PointerType:$result);

let assemblyFormat = [{
$vptr `[` $index `]` attr-dict
`:` qualified(type($vptr)) `->` qualified(type($result))
}];
}
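
The five-step sequence in the example (get the vptr address, load the vptr, compute the slot address, load the function pointer, call it) can be sketched in standalone C++. This is only an illustration under stated assumptions: `VPtr`, `Object`, `getVPtr`, `getVirtualFnAddr`, `callVirtual` and the sample functions are hypothetical names that do not exist in ClangIR, and the vtable is reduced to a bare array of function pointers.

```cpp
#include <cassert>
#include <cstddef>

struct Object;
using VirtualFn = int (*)(Object *self);

// Hypothetical stand-in for !cir.vptr: a distinct type wrapping the address
// point instead of an arbitrary pointer-to-pointer-to-function.
struct VPtr {
  const VirtualFn *addressPoint;
};

struct Object {
  VPtr vptr;    // assumed to live at offset zero, as the op description says
  int payload;
};
static_assert(offsetof(Object, vptr) == 0, "vptr must be the first member");

// Analogue of cir.vtable.get_vptr: object pointer -> address of its vptr.
inline VPtr *getVPtr(Object *obj) { return &obj->vptr; }

// Analogue of cir.vtable.get_virtual_fn_addr: vptr value + index -> slot address.
inline const VirtualFn *getVirtualFnAddr(VPtr vptr, std::size_t index) {
  return vptr.addressPoint + index;
}

// The whole dispatch, one statement per SSA value in the MLIR example.
inline int callVirtual(Object *obj, std::size_t index) {
  assert(obj != nullptr);
  VPtr *vptrAddr = getVPtr(obj);                          // %3
  VPtr vptr = *vptrAddr;                                  // %4
  const VirtualFn *slot = getVirtualFnAddr(vptr, index);  // %5
  VirtualFn fn = *slot;                                   // %6
  return fn(obj);                                         // %7
}

// Sample "virtual functions" and a three-slot table for demonstration.
inline int returnSeven(Object *) { return 7; }
inline int returnPayload(Object *self) { return self->payload; }
inline int returnDoubled(Object *self) { return 2 * self->payload; }

static const VirtualFn kSlots[] = {returnSeven, returnPayload, returnDoubled};
```

An object built as `Object o{VPtr{kSlots}, 21};` dispatches through slot 2 with `callVirtual(&o, 2)`, matching the `%4[2]` index used in the op's example.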

//===----------------------------------------------------------------------===//
Expand Down
10 changes: 9 additions & 1 deletion clang/include/clang/CIR/Dialect/IR/CIRTypeConstraints.td
Expand Up @@ -263,13 +263,21 @@ def CIR_PtrToExceptionInfoType
def CIR_AnyDataMemberType : CIR_TypeBase<"::cir::DataMemberType",
"data member type">;

//===----------------------------------------------------------------------===//
// VPtr type predicates
//===----------------------------------------------------------------------===//

def CIR_AnyVPtrType : CIR_TypeBase<"::cir::VPtrType", "vptr type">;

def CIR_PtrToVPtr : CIR_PtrToType<CIR_AnyVPtrType>;

//===----------------------------------------------------------------------===//
// Scalar Type predicates
//===----------------------------------------------------------------------===//

defvar CIR_ScalarTypes = [
CIR_AnyBoolType, CIR_AnyIntType, CIR_AnyFloatType, CIR_AnyPtrType,
CIR_AnyDataMemberType
CIR_AnyDataMemberType, CIR_AnyVPtrType
];

def CIR_AnyScalarType : AnyTypeOf<CIR_ScalarTypes, "cir scalar type"> {
Expand Down
34 changes: 33 additions & 1 deletion clang/include/clang/CIR/Dialect/IR/CIRTypes.td
Expand Up @@ -343,6 +343,37 @@ def CIR_DataMemberType : CIR_Type<"DataMember", "data_member",
}];
}

//===----------------------------------------------------------------------===//
// CIR_VPtrType
//===----------------------------------------------------------------------===//

def CIR_VPtrType : CIR_Type<"VPtr", "vptr", [
DeclareTypeInterfaceMethods<DataLayoutTypeInterface>
]> {

let summary = "CIR type that is used for the vptr member of C++ objects";
let description = [{
`!cir.vptr` is a special type used as the type for the vptr member of a C++
object. This avoids using arbitrary pointer types to declare vptr values
and allows stronger type-based checking for operations that use or provide
access to the vptr.

This type will be the element type of the 'vptr' member of structures that
require a vtable pointer. The `cir.vtable.address_point` operation returns a
value of this type, and the `cir.vtable.get_vptr` operation returns a
pointer to it. A `!cir.vptr` value may be passed to the
`cir.vtable.get_virtual_fn_addr` operation to get the address of a virtual
function pointer.

A `!cir.vptr` value may also be cast to pointer types in order to perform
pointer arithmetic based on information encoded in the AST layout to get
the offset from a pointer to a dynamic object to the base object pointer,
the base object offset value from the vtable, or the type information
entry for an object.
TODO: We should have special operations to do that too.
}];
}
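
The benefit of a dedicated type over an arbitrary pointer type can be illustrated outside of MLIR with a small C++ wrapper. This is a loose analogy only, assuming a hypothetical `VPtr` class with an explicit constructor and an explicit escape hatch `asBytePtr` (standing in for the cast from `!cir.vptr` to a pointer type that the verifier still permits); none of these names come from the PR.

```cpp
#include <cassert>
#include <type_traits>

// Hypothetical wrapper: a vptr is its own type, not "some pointer".
class VPtr {
public:
  explicit VPtr(const void *addressPoint) : addressPoint_(addressPoint) {}

  // Explicit conversion for the cases that still need raw pointer arithmetic,
  // such as reading offsets or type info stored next to the address point.
  const unsigned char *asBytePtr() const {
    return static_cast<const unsigned char *>(addressPoint_);
  }

private:
  const void *addressPoint_;
};

// An operation that accepts only vptr values, never arbitrary pointers.
inline bool isNullVPtr(VPtr vptr) { return vptr.asBytePtr() == nullptr; }

// The compiler rejects accidental mixing in both directions.
static_assert(!std::is_convertible<const void *, VPtr>::value,
              "raw pointers do not silently become vptrs");
static_assert(!std::is_convertible<VPtr, const void *>::value,
              "vptrs do not silently decay to raw pointers");
```

The two `static_assert`s capture the intent: mistakes that a pointer-to-pointer-to-function encoding would accept are rejected at the type level, while deliberate conversions stay possible through one clearly named path.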


//===----------------------------------------------------------------------===//
// BoolType
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -751,7 +782,8 @@ def CIRRecordType : Type<
def CIR_AnyType : AnyTypeOf<[
CIR_IntType, CIR_PointerType, CIR_DataMemberType, CIR_MethodType,
CIR_BoolType, CIR_ArrayType, CIR_VectorType, CIR_FuncType, CIR_VoidType,
CIR_RecordType, CIR_ExceptionType, CIR_AnyFloatType, CIR_ComplexType
CIR_RecordType, CIR_ExceptionType, CIR_AnyFloatType, CIR_ComplexType,
CIR_VPtrType
]>;

#endif // MLIR_CIR_DIALECT_CIR_TYPES
8 changes: 2 additions & 6 deletions clang/lib/CIR/CodeGen/CIRGenBuilder.h
Expand Up @@ -425,12 +425,8 @@ class CIRGenBuilderTy : public cir::CIRBaseBuilderTy {
llvm_unreachable("unsupported long double format");
}

mlir::Type getVirtualFnPtrType(bool isVarArg = false) {
// FIXME: replay LLVM codegen for now, perhaps add a vtable ptr special
// type so it's a bit more clear and C++ idiomatic.
auto fnTy = cir::FuncType::get({}, getUInt32Ty(), isVarArg);
assert(!cir::MissingFeatures::isVarArg());
return getPointerTo(getPointerTo(fnTy));
mlir::Type getPtrToVPtrType() {
return getPointerTo(cir::VPtrType::get(getContext()));
}

cir::FuncType getFuncType(llvm::ArrayRef<mlir::Type> params, mlir::Type retTy,
Expand Down
12 changes: 9 additions & 3 deletions clang/lib/CIR/CodeGen/CIRGenClass.cpp
Expand Up @@ -806,6 +806,10 @@ void CIRGenFunction::initializeVTablePointer(mlir::Location loc,
//
// vtable field is derived from `this` pointer, therefore they should be in
// the same addr space.
// TODO(cir): We should be using cir.vtable.get_vptr rather than a bitcast
// to get the vptr field, but the call to
// ApplyNonVirtualAndVirtualOffset will also need to be adjusted.
// That should probably be using cir.base_class_addr.
assert(!cir::MissingFeatures::addressSpace());
VTableField = builder.createElementBitCast(loc, VTableField,
VTableAddressPoint.getType());
Expand Down Expand Up @@ -1704,10 +1708,12 @@ void CIRGenFunction::emitTypeMetadataCodeForVCall(const CXXRecordDecl *RD,
}

mlir::Value CIRGenFunction::getVTablePtr(mlir::Location Loc, Address This,
mlir::Type VTableTy,
const CXXRecordDecl *RD) {
Address VTablePtrSrc = builder.createElementBitCast(Loc, This, VTableTy);
auto VTable = builder.createLoad(Loc, VTablePtrSrc);
auto VTablePtr = builder.create<cir::VTableGetVPtrOp>(
Loc, builder.getPtrToVPtrType(), This.getPointer());
Address VTablePtrAddr = Address(VTablePtr, This.getAlignment());

auto VTable = builder.createLoad(Loc, VTablePtrAddr);
assert(!cir::MissingFeatures::tbaa());

if (CGM.getCodeGenOpts().OptimizationLevel > 0 &&
Expand Down
1 change: 0 additions & 1 deletion clang/lib/CIR/CodeGen/CIRGenFunction.h
Expand Up @@ -957,7 +957,6 @@ class CIRGenFunction : public CIRGenTypeCache {
VisitedVirtualBasesSetTy &VBases, VPtrsVector &vptrs);
/// Return the Value of the vtable pointer member pointed to by This.
mlir::Value getVTablePtr(mlir::Location Loc, Address This,
mlir::Type VTableTy,
const CXXRecordDecl *VTableClass);

/// Returns whether we should perform a type checked load when loading a
Expand Down
29 changes: 13 additions & 16 deletions clang/lib/CIR/CodeGen/CIRGenItaniumCXXABI.cpp
Expand Up @@ -935,11 +935,11 @@ cir::GlobalOp CIRGenItaniumCXXABI::getAddrOfVTable(const CXXRecordDecl *RD,
CIRGenCallee CIRGenItaniumCXXABI::getVirtualFunctionPointer(
CIRGenFunction &CGF, GlobalDecl GD, Address This, mlir::Type Ty,
SourceLocation Loc) {
auto &builder = CGM.getBuilder();
auto loc = CGF.getLoc(Loc);
auto TyPtr = CGF.getBuilder().getPointerTo(Ty);
auto TyPtr = builder.getPointerTo(Ty);
auto *MethodDecl = cast<CXXMethodDecl>(GD.getDecl());
auto VTable = CGF.getVTablePtr(
loc, This, CGF.getBuilder().getPointerTo(TyPtr), MethodDecl->getParent());
auto VTable = CGF.getVTablePtr(loc, This, MethodDecl->getParent());

uint64_t VTableIndex = CGM.getItaniumVTableContext().getMethodVTableIndex(GD);
mlir::Value VFunc{};
Expand All @@ -952,15 +952,10 @@ CIRGenCallee CIRGenItaniumCXXABI::getVirtualFunctionPointer(
if (CGM.getItaniumVTableContext().isRelativeLayout()) {
llvm_unreachable("NYI");
} else {
VTable = CGF.getBuilder().createBitcast(
loc, VTable, CGF.getBuilder().getPointerTo(TyPtr));
auto VTableSlotPtr = CGF.getBuilder().create<cir::VTableAddrPointOp>(
loc, CGF.getBuilder().getPointerTo(TyPtr),
::mlir::FlatSymbolRefAttr{}, VTable,
cir::AddressPointAttr::get(CGF.getBuilder().getContext(), 0,
VTableIndex));
VFuncLoad = CGF.getBuilder().createAlignedLoad(loc, TyPtr, VTableSlotPtr,
CGF.getPointerAlign());
auto VTableSlotPtr = builder.create<cir::VTableGetVirtualFnAddrOp>(
loc, builder.getPointerTo(TyPtr), VTable, VTableIndex);
VFuncLoad = builder.createAlignedLoad(loc, TyPtr, VTableSlotPtr,
CGF.getPointerAlign());
}

// Add !invariant.load md to virtual function load to indicate that
Expand Down Expand Up @@ -1014,11 +1009,11 @@ CIRGenItaniumCXXABI::getVTableAddressPoint(BaseSubobject Base,
.getAddressPoint(Base);

auto &builder = CGM.getBuilder();
auto vtablePtrTy = builder.getVirtualFnPtrType(/*isVarArg=*/false);
auto vtablePtrTy = cir::VPtrType::get(builder.getContext());

return builder.create<cir::VTableAddrPointOp>(
CGM.getLoc(VTableClass->getSourceRange()), vtablePtrTy,
mlir::FlatSymbolRefAttr::get(vtable.getSymNameAttr()), mlir::Value{},
mlir::FlatSymbolRefAttr::get(vtable.getSymNameAttr()),
cir::AddressPointAttr::get(CGM.getBuilder().getContext(),
AddressPoint.VTableIndex,
AddressPoint.AddressPointIndex));
Expand Down Expand Up @@ -2410,14 +2405,16 @@ void CIRGenItaniumCXXABI::emitThrow(CIRGenFunction &CGF,
mlir::Value CIRGenItaniumCXXABI::getVirtualBaseClassOffset(
mlir::Location loc, CIRGenFunction &CGF, Address This,
const CXXRecordDecl *ClassDecl, const CXXRecordDecl *BaseClassDecl) {
auto VTablePtr = CGF.getVTablePtr(loc, This, CGM.UInt8PtrTy, ClassDecl);
auto VTablePtr = CGF.getVTablePtr(loc, This, ClassDecl);
auto VTableBytePtr =
CGF.getBuilder().createBitcast(VTablePtr, CGM.UInt8PtrTy);
CharUnits VBaseOffsetOffset =
CGM.getItaniumVTableContext().getVirtualBaseOffsetOffset(ClassDecl,
BaseClassDecl);
mlir::Value OffsetVal =
CGF.getBuilder().getSInt64(VBaseOffsetOffset.getQuantity(), loc);
auto VBaseOffsetPtr = CGF.getBuilder().create<cir::PtrStrideOp>(
loc, VTablePtr.getType(), VTablePtr,
loc, CGM.UInt8PtrTy, VTableBytePtr,
OffsetVal); // vbase.offset.ptr

mlir::Value VBaseOffset;
Expand Down
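
The hunk above now casts the loaded vptr to a byte pointer before striding by the virtual-base-offset offset. A standalone C++ sketch of that arithmetic follows. It assumes, as in the Itanium layout, that the offset is a byte distance from the address point (typically negative) and that the slot holds a `ptrdiff_t`; the load step is not visible in the truncated hunk, so it is an assumption here, and `loadVBaseOffset` is a hypothetical helper name.

```cpp
#include <cassert>
#include <cstddef>
#include <cstring>

// Reads the virtual base offset stored `vbaseOffsetOffset` bytes away from
// the vtable address point.
inline std::ptrdiff_t loadVBaseOffset(const void *addressPoint,
                                      std::ptrdiff_t vbaseOffsetOffset) {
  // Step 1: view the vptr as a byte pointer (the createBitcast in the hunk).
  const unsigned char *bytes =
      static_cast<const unsigned char *>(addressPoint);
  // Step 2: stride by the byte offset (the cir::PtrStrideOp in the hunk).
  const unsigned char *slot = bytes + vbaseOffsetOffset;
  // Step 3 (assumed): load the offset; memcpy avoids aliasing issues.
  std::ptrdiff_t value;
  std::memcpy(&value, slot, sizeof value);
  return value;
}
```

Striding in bytes is what makes the explicit cast necessary: with a typed `!cir.vptr` value there is no element size to stride by, so the code first moves to a byte pointer type.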
4 changes: 1 addition & 3 deletions clang/lib/CIR/CodeGen/CIRRecordLayoutBuilder.cpp
Expand Up @@ -488,9 +488,7 @@ void CIRRecordLowering::accumulateVPtrs() {
}

mlir::Type CIRRecordLowering::getVFPtrType() {
// FIXME: replay LLVM codegen for now, perhaps add a vtable ptr special
// type so it's a bit more clear and C++ idiomatic.
return builder.getVirtualFnPtrType();
return cir::VPtrType::get(builder.getContext());
}

void CIRRecordLowering::fillOutputFields() {
Expand Down
32 changes: 7 additions & 25 deletions clang/lib/CIR/Dialect/IR/CIRDialect.cpp
Expand Up @@ -561,6 +561,12 @@ LogicalResult cir::CastOp::verify() {
return success();
}

// Allow casting cir.vptr to pointer types.
// TODO: Add operations to get object offset and type info and remove this.
if (mlir::isa<cir::VPtrType>(srcType) &&
mlir::dyn_cast<cir::PointerType>(resType))
return success();

// Handle the data member pointer types.
if (mlir::isa<cir::DataMemberType>(srcType) &&
mlir::isa<cir::DataMemberType>(resType))
Expand Down Expand Up @@ -2390,10 +2396,7 @@ cir::GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {

LogicalResult
cir::VTableAddrPointOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
// vtable ptr is not coming from a symbol.
if (!getName())
return success();
auto name = *getName();
StringRef name = getName();

// Verify that the result type underlying pointer type matches the type of
// the referenced cir.global or cir.func op.
Expand All @@ -2411,27 +2414,6 @@ cir::VTableAddrPointOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
return success();
}

LogicalResult cir::VTableAddrPointOp::verify() {
// The operation uses either a symbol or a value to operate, but not both
if (getName() && getSymAddr())
return emitOpError("should use either a symbol or value, but not both");

// If not a symbol, stick with the concrete type used for getSymAddr.
if (getSymAddr())
return success();

auto resultType = getAddr().getType();
auto intTy = cir::IntType::get(getContext(), 32, /*isSigned=*/false);
auto fnTy = cir::FuncType::get({}, intTy);

auto resTy = cir::PointerType::get(cir::PointerType::get(fnTy));

if (resultType != resTy)
return emitOpError("result type must be '")
<< resTy << "', but provided result type is '" << resultType << "'";
return success();
}

//===----------------------------------------------------------------------===//
// VTTAddrPointOp
//===----------------------------------------------------------------------===//
Expand Down