Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 14 additions & 3 deletions src/lython/dialects/cpp/PyVerifier/ClassFunc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,20 @@ static bool isSupportedStaticFieldType(Type type) {
if (isa<IntType, FloatType, BoolType, StrType, NoneType, IntegerType,
mlir::FloatType>(type))
return true;
if (auto listType = dyn_cast<ListType>(type)) {
Type elementType = listType.getElementType();
auto isSupportedContainerElement = [](Type elementType) {
return isa<ClassType, IntType, FloatType, BoolType, StrType, NoneType,
ObjectType>(elementType);
};
if (auto listType = dyn_cast<ListType>(type)) {
return isSupportedContainerElement(listType.getElementType());
}
if (auto dictType = dyn_cast<DictType>(type)) {
return isSupportedContainerElement(dictType.getKeyType()) &&
isSupportedContainerElement(dictType.getValueType());
}
if (auto tupleType = dyn_cast<TupleType>(type)) {
return llvm::all_of(tupleType.getElementTypes(),
isSupportedContainerElement);
}
return false;
}
Expand Down Expand Up @@ -47,7 +57,8 @@ static LogicalResult verifyClassFieldSchema(ClassOp op) {
return op.emitOpError("unsupported static field type ")
<< typeAttr.getValue()
<< "; supported field types are !py.int, !py.float, !py.bool, "
"!py.none, integers, and floats";
"!py.str, !py.none, integers, floats, typed lists, and typed "
"dicts/tuples";
}

return success();
Expand Down
126 changes: 124 additions & 2 deletions src/lython/lowering/Common/LoweringUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/BuiltinOps.h"

namespace py {
Expand Down Expand Up @@ -38,9 +39,130 @@ getOrInsertLLVMFunc(mlir::Location loc, mlir::ModuleOp module,
return builder.create<mlir::LLVM::LLVMFuncOp>(loc, name, fnType);
}

enum class TypedContainerSlotPolicy {
Unsupported,
NativeInteger,
NativeBool,
NativeFloat,
PointerBits,
};

inline TypedContainerSlotPolicy getTypedContainerSlotPolicy(mlir::Type type) {
if (mlir::isa<IntType>(type))
return TypedContainerSlotPolicy::NativeInteger;
if (mlir::isa<BoolType>(type))
return TypedContainerSlotPolicy::NativeBool;
if (mlir::isa<FloatType>(type))
return TypedContainerSlotPolicy::NativeFloat;
if (mlir::isa<StrType, ObjectType, ClassType>(type))
return TypedContainerSlotPolicy::PointerBits;
return TypedContainerSlotPolicy::Unsupported;
}

inline bool isTypedContainerSlotSupported(mlir::Type type) {
return getTypedContainerSlotPolicy(type) !=
TypedContainerSlotPolicy::Unsupported;
}

inline bool usesPackedI64BootstrapSlot(mlir::Type type) {
switch (getTypedContainerSlotPolicy(type)) {
case TypedContainerSlotPolicy::NativeInteger:
case TypedContainerSlotPolicy::NativeBool:
case TypedContainerSlotPolicy::NativeFloat:
case TypedContainerSlotPolicy::PointerBits:
return true;
case TypedContainerSlotPolicy::Unsupported:
return false;
}
return false;
}

inline mlir::Type getTypedContainerElementStorageType(mlir::Type logicalType,
mlir::MLIRContext *ctx) {
switch (getTypedContainerSlotPolicy(logicalType)) {
case TypedContainerSlotPolicy::NativeInteger:
return mlir::IntegerType::get(ctx, 64);
case TypedContainerSlotPolicy::NativeBool:
return mlir::IntegerType::get(ctx, 8);
case TypedContainerSlotPolicy::NativeFloat:
return mlir::Float64Type::get(ctx);
case TypedContainerSlotPolicy::PointerBits:
return mlir::IntegerType::get(ctx, 64);
case TypedContainerSlotPolicy::Unsupported:
return {};
}
return {};
}

inline mlir::MemRefType getListHeaderMemRefType(mlir::MLIRContext *ctx) {
return mlir::MemRefType::get({4}, mlir::IntegerType::get(ctx, 64));
}

inline mlir::MemRefType getListItemsMemRefType(mlir::Type elementType,
mlir::MLIRContext *ctx) {
mlir::Type storageType =
getTypedContainerElementStorageType(elementType, ctx);
if (!storageType)
return {};
return mlir::MemRefType::get({mlir::ShapedType::kDynamic}, storageType);
}

inline mlir::MemRefType getTupleHeaderMemRefType(mlir::MLIRContext *ctx) {
return mlir::MemRefType::get({3}, mlir::IntegerType::get(ctx, 64));
}

inline mlir::Type getTupleItemsStorageType(TupleType tupleType,
mlir::MLIRContext *ctx) {
auto elementTypes = tupleType.getElementTypes();
if (elementTypes.empty())
return mlir::IntegerType::get(ctx, 64);
mlir::Type firstStorage =
getTypedContainerElementStorageType(elementTypes.front(), ctx);
if (!firstStorage)
return mlir::IntegerType::get(ctx, 64);
for (mlir::Type elementType : elementTypes.drop_front()) {
mlir::Type storage = getTypedContainerElementStorageType(elementType, ctx);
if (storage != firstStorage)
return mlir::IntegerType::get(ctx, 64);
}
return firstStorage;
}

inline mlir::MemRefType getTupleItemsMemRefType(TupleType tupleType,
mlir::MLIRContext *ctx) {
return mlir::MemRefType::get({mlir::ShapedType::kDynamic},
getTupleItemsStorageType(tupleType, ctx));
}

inline mlir::MemRefType getDictHeaderMemRefType(mlir::MLIRContext *ctx) {
return mlir::MemRefType::get({5}, mlir::IntegerType::get(ctx, 64));
}

inline mlir::MemRefType getDictKeysMemRefType(DictType dictType,
mlir::MLIRContext *ctx) {
mlir::Type storageType =
getTypedContainerElementStorageType(dictType.getKeyType(), ctx);
if (!storageType)
return {};
return mlir::MemRefType::get({mlir::ShapedType::kDynamic}, storageType);
}

inline mlir::MemRefType getDictValuesMemRefType(DictType dictType,
mlir::MLIRContext *ctx) {
mlir::Type storageType =
getTypedContainerElementStorageType(dictType.getValueType(), ctx);
if (!storageType)
return {};
return mlir::MemRefType::get({mlir::ShapedType::kDynamic}, storageType);
}

inline mlir::MemRefType getDictStatesMemRefType(mlir::MLIRContext *ctx) {
return mlir::MemRefType::get({mlir::ShapedType::kDynamic},
mlir::IntegerType::get(ctx, 8));
}

inline bool isMemRefSlotCompatibleScalarType(mlir::Type type) {
return mlir::isa<IntType, BoolType, FloatType, StrType, ObjectType,
ClassType>(type);
return isTypedContainerSlotSupported(type);
}

inline bool isCompilerOwnedMemRefListType(mlir::Type type) {
Expand Down
84 changes: 57 additions & 27 deletions src/lython/lowering/Common/RuntimeSupport.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#include "Common/RuntimeSupport.h"

#include "Common/LoweringUtils.h"

#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/FormatVariadic.h"
Expand Down Expand Up @@ -36,47 +38,75 @@ PyLLVMTypeConverter::PyLLVMTypeConverter(mlir::MLIRContext *ctx)
return std::nullopt;
});

addConversion([ctx](ListType listType) -> std::optional<mlir::Type> {
mlir::Type elementType = listType.getElementType();
if (mlir::isa<IntType, BoolType, FloatType, ClassType>(elementType))
return mlir::MemRefType::get({mlir::ShapedType::kDynamic},
mlir::IntegerType::get(ctx, 64));
return std::nullopt;
});

addConversion([ctx](TupleType tupleType) -> std::optional<mlir::Type> {
(void)tupleType;
return mlir::MemRefType::get({mlir::ShapedType::kDynamic},
mlir::IntegerType::get(ctx, 64));
});

addConversion([ctx](DictType dictType) -> std::optional<mlir::Type> {
auto isSlotType = [](mlir::Type type) {
return mlir::isa<IntType, BoolType, FloatType, StrType, ObjectType,
ClassType>(type);
};
if (isSlotType(dictType.getKeyType()) &&
isSlotType(dictType.getValueType()))
return mlir::MemRefType::get({mlir::ShapedType::kDynamic},
mlir::IntegerType::get(ctx, 64));
return std::nullopt;
});
addConversion(
[ctx](ListType listType, mlir::SmallVectorImpl<mlir::Type> &results)
-> std::optional<mlir::LogicalResult> {
auto itemsType = getListItemsMemRefType(listType.getElementType(), ctx);
if (!itemsType)
return std::nullopt;
results.push_back(getListHeaderMemRefType(ctx));
results.push_back(itemsType);
return mlir::success();
});

addConversion(
[ctx](TupleType tupleType, mlir::SmallVectorImpl<mlir::Type> &results)
-> std::optional<mlir::LogicalResult> {
results.push_back(getTupleHeaderMemRefType(ctx));
results.push_back(getTupleItemsMemRefType(tupleType, ctx));
return mlir::success();
});

addConversion(
[ctx](DictType dictType, mlir::SmallVectorImpl<mlir::Type> &results)
-> std::optional<mlir::LogicalResult> {
if (isTypedContainerSlotSupported(dictType.getKeyType()) &&
isTypedContainerSlotSupported(dictType.getValueType())) {
results.push_back(getDictHeaderMemRefType(ctx));
results.push_back(getDictKeysMemRefType(dictType, ctx));
results.push_back(getDictValuesMemRefType(dictType, ctx));
results.push_back(getDictStatesMemRefType(ctx));
return mlir::success();
}
return std::nullopt;
});

auto materializeBridge = [](mlir::OpBuilder &builder, mlir::Type resultType,
mlir::ValueRange inputs,
mlir::Location loc) -> mlir::Value {
if (inputs.size() != 1)
if (inputs.empty())
return {};
mlir::Type inputType = inputs.front().getType();
if (!isPyRuntimeBridgeType(resultType) && !isPyRuntimeBridgeType(inputType))
bool inputIsPyBridge =
isPyRuntimeBridgeType(inputType) ||
llvm::all_of(inputs, [](mlir::Value input) {
return mlir::isa<mlir::MemRefType>(input.getType());
});
if (!isPyRuntimeBridgeType(resultType) && !inputIsPyBridge)
return {};
return builder
.create<mlir::UnrealizedConversionCastOp>(loc, resultType, inputs)
.getResult(0);
};

auto materializeTargetBridge =
[](mlir::OpBuilder &builder, mlir::TypeRange resultTypes,
mlir::ValueRange inputs, mlir::Location loc,
mlir::Type originalType) -> mlir::SmallVector<mlir::Value> {
if (inputs.size() != 1 || resultTypes.empty())
return {};
if (!isPyRuntimeBridgeType(originalType))
return {};
auto cast = builder.create<mlir::UnrealizedConversionCastOp>(
loc, resultTypes, inputs);
mlir::SmallVector<mlir::Value> results;
results.append(cast.getResults().begin(), cast.getResults().end());
return results;
};

addSourceMaterialization(materializeBridge);
addTargetMaterialization(materializeBridge);
addTargetMaterialization(materializeTargetBridge);
}

RuntimeAPI::RuntimeAPI(mlir::ModuleOp module,
Expand Down
Loading
Loading