diff --git a/mlir/include/mlir/Dialect/Quant/IR/QuantTypes.h b/mlir/include/mlir/Dialect/Quant/IR/QuantTypes.h index 324380fd388f..e796fb38a74b 100644 --- a/mlir/include/mlir/Dialect/Quant/IR/QuantTypes.h +++ b/mlir/include/mlir/Dialect/Quant/IR/QuantTypes.h @@ -97,6 +97,12 @@ class QuantizedType : public Type { return -getDefaultMaximumForF8E5M2(); } + static constexpr int64_t getDefaultMaximumForF4E2M1FN() { return 6; } + + static constexpr int64_t getDefaultMinimumForF4E2M1FN() { + return -getDefaultMaximumForF4E2M1FN(); + } + /// Gets the original expressed type that this quantized type approximates. /// Note that this presumes that the quantized type was always derived from /// a floating point type, which in the broadest definition, is not true (i.e. @@ -267,7 +273,7 @@ class AnyQuantizedType /// Per-layer, optional parameters omitted: /// !quant /// -/// StorageType: 'i'|'u' NumBits +/// StorageType: 'i'|'u' NumBits, 'f4', 'hf8', 'bf8' /// ExpressedType: 'f16', 'f32', 'bf16', 'f64' /// Scale: A legal double value /// ZeroPoint: An integer value @@ -327,7 +333,7 @@ class UniformQuantizedType /// Per-axis, optional parameters omitted: /// !quant /// -/// StorageType: 'i'|'u' NumBits +/// StorageType: 'i'|'u' NumBits, 'f4', 'hf8', 'bf8' /// ExpressedType: 'f16', 'f32', 'bf16', 'f64' /// QuantizedDim: An integer value /// QuantParams: (Scale ':' ZeroPoint)+ @@ -414,7 +420,7 @@ class UniformQuantizedPerAxisType /// ScaleZeroList ::= ScaleZero (',' ScaleZero)* /// ScaleZero ::= Scale (':' ZeroPoint)? /// -/// StorageType: 'i'|'u' NumBits +/// StorageType: 'i'|'u' NumBits, 'f4', 'hf8', 'bf8' /// ExpressedType: 'f16', 'f32', 'bf16', 'f64' /// AxisSpec: An integer value /// BlockSizeSpec: An integer value @@ -533,7 +539,7 @@ class UniformQuantizedSubChannelType /// QuantileQuantizedType derives from UniformQuantizedType and adds to it a /// look up table array of quantile values. The type of the data in the look up table is determined by -/// the quantileType member: supported quantileType types are integer/unsigned/hf8/bf8/f16/bf16/f32/f64. +/// the quantileType member: supported quantileType types are integer/unsigned/f4/hf8/bf8/f16/bf16/f32/f64. /// /// Syntax synopsis: /// Per-layer, all parameters expressed: @@ -541,8 +547,8 @@ class UniformQuantizedSubChannelType /// Per-layer, optional parameters omitted: /// !quant /// -/// StorageType: 'i'|'u' NumBits -/// QuantileType: 'i'|'u' NumBits, 'hf8', 'bf8', 'f16', 'bf16', 'f32', 'f64' +/// StorageType: 'i'|'u' NumBits, 'f4', 'hf8', 'bf8' +/// QuantileType: 'i'|'u' NumBits, 'f4', 'hf8', 'bf8', 'f16', 'bf16', 'f32', 'f64' /// ExpressedType: 'f16', 'f32', 'bf16', 'f64' /// Quantiles: Quantile+ /// Quantile: A legal double value @@ -600,7 +606,7 @@ class QuantileQuantizedType /// Represents per-axis QuantileQuantizedType (also known as per-channel /// quantization). The type of the data in the look up table is determined by the -/// quantileType member: supported quantileType types are integer/unsigned/hf8/bf8/f16/bf16/f32/f64. +/// quantileType member: supported quantileType types are integer/unsigned/f4/hf8/bf8/f16/bf16/f32/f64. /// /// Syntax synopsis: /// Per-axis, all parameters expressed: @@ -608,8 +614,8 @@ class QuantileQuantizedType /// Per-axis, optional parameters omitted: /// !quant /// -/// StorageType: 'i'|'u' NumBits -/// QuantileType: 'i'|'u' NumBits, 'hf8', 'bf8', 'f16', 'bf16', 'f32', 'f64' +/// StorageType: 'i'|'u' NumBits, 'f4', 'hf8', 'bf8' +/// QuantileType: 'i'|'u' NumBits, 'f4', 'hf8', 'bf8', 'f16', 'bf16', 'f32', 'f64' /// ExpressedType: 'f16', 'f32', 'bf16', 'f64' /// QuantizedDim: An integer value /// Quantiles: Quantile+ diff --git a/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp b/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp index 9034080a6503..c4f2e21430f2 100644 --- a/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp +++ b/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp @@ -67,15 +67,18 @@ QuantizedType::verifyInvariants(function_ref emitError, const auto width = llvm::dyn_cast(storageType).getWidth(); defaultMin = QuantizedType::getDefaultMinimumForInteger(isSigned, width); defaultMax = QuantizedType::getDefaultMaximumForInteger(isSigned, width); - } else if (storageType.isa()) { + } else if (mlir::isa(storageType)) { defaultMin = QuantizedType::getDefaultMinimumForF8E5M2(); defaultMax = QuantizedType::getDefaultMaximumForF8E5M2(); - } else if (storageType.isa()) { + } else if (mlir::isa(storageType)) { defaultMin = QuantizedType::getDefaultMinimumForF8E4M3FN(); defaultMax = QuantizedType::getDefaultMaximumForF8E4M3FN(); + } else if (mlir::isa(storageType)) { + defaultMin = QuantizedType::getDefaultMinimumForF4E2M1FN(); + defaultMax = QuantizedType::getDefaultMaximumForF4E2M1FN(); } else { return emitError() << "illegal storage type, supported types are: integral " - "types, Float8E4M3FNType and Float8E5M2Type "; + "types, Float8E4M3FNType, Float8E5M2Type and Float4E2M1FNType "; } // Verify storageTypeMin and storageTypeMax. @@ -574,19 +577,18 @@ LogicalResult QuantileQuantizedType::verifyInvariants( unsigned typeWidth{}; if (storageType.isa()) { typeWidth = llvm::dyn_cast(storageType).getWidth(); - } else if (storageType.isa() || - storageType.isa()) { - // Both Float8E5M2Type and Float8E4M3FNType derive from FloatType. + } else if (mlir::isa(storageType)) { + // Float8E5M2Type, Float8E4M3FNType and Float4E2M1FNType derive from FloatType. typeWidth = llvm::dyn_cast(storageType).getWidth(); } else { return emitError() << "illegal storage type, supported types are: integral " - "types, Float8E4M3FNType and Float8E5M2Type "; + "types, Float8E4M3FNType, Float8E5M2Type and Float4E2M1FNType "; } const size_t storageTypeRange = storageTypeMax - storageTypeMin + 1; const size_t typeWidthSize = 1 << typeWidth; const size_t expectedSize = - (storageTypeRange < typeWidthSize) ? storageTypeRange : typeWidthSize; + (storageTypeRange < typeWidthSize) && !mlir::isa(storageType) ? storageTypeRange : typeWidthSize; const auto quantileArraySize = quantiles.size(); if (quantileArraySize != expectedSize) { @@ -660,19 +662,18 @@ LogicalResult QuantileQuantizedPerAxisType::verifyInvariants( unsigned typeWidth{}; if (storageType.isa()) { typeWidth = llvm::dyn_cast(storageType).getWidth(); - } else if (storageType.isa() || - storageType.isa()) { - // Both Float8E5M2Type and Float8E4M3FNType derive from FloatType. + } else if (mlir::isa(storageType)) { + // Float8E5M2Type, Float8E4M3FNType and Float4E2M1FNType derive from FloatType. typeWidth = llvm::dyn_cast(storageType).getWidth(); } else { return emitError() << "illegal storage type, supported types are: integral " - "types, Float8E4M3FNType and Float8E5M2Type "; + "types, Float8E4M3FNType, Float8E5M2Type and Float4E2M1FNType "; } const size_t storageTypeRange = storageTypeMax - storageTypeMin + 1; const size_t typeWidthSize = 1 << typeWidth; const size_t expectedSize = - (storageTypeRange < typeWidthSize) ? storageTypeRange : typeWidthSize; + (storageTypeRange < typeWidthSize) && !mlir::isa(storageType) ? storageTypeRange : typeWidthSize; const auto quantileArraySize = quantiles.size(); if (quantileArraySize != expectedSize) { diff --git a/mlir/lib/Dialect/Quant/IR/TypeParser.cpp b/mlir/lib/Dialect/Quant/IR/TypeParser.cpp index 50405da981e3..fa7a99e9b32b 100644 --- a/mlir/lib/Dialect/Quant/IR/TypeParser.cpp +++ b/mlir/lib/Dialect/Quant/IR/TypeParser.cpp @@ -35,9 +35,8 @@ static Type parseStorageType(DialectAsmParser &parser, bool &isSigned) { if (auto intType = llvm::dyn_cast(type)) { isSigned = !intType.isUnsigned(); storageTypeWidth = intType.getWidth(); - } else if (llvm::dyn_cast(type) || - llvm::dyn_cast(type)) { - storageTypeWidth = 8; + } else if (mlir::isa(type)) { + storageTypeWidth = llvm::dyn_cast(type).getWidth(); isSigned = true; } else { parser.emitError(typeLoc, "illegal quantized storage type alias"); @@ -132,12 +131,15 @@ static ParseResult parseStorageRange(DialectAsmParser &parser, Type storageType, const auto width = llvm::dyn_cast(storageType).getWidth(); defaultMin = QuantizedType::getDefaultMinimumForInteger(isSigned, width); defaultMax = QuantizedType::getDefaultMaximumForInteger(isSigned, width); - } else if (storageType.isa()) { + } else if (mlir::isa(storageType)) { defaultMin = QuantizedType::getDefaultMinimumForF8E5M2(); defaultMax = QuantizedType::getDefaultMaximumForF8E5M2(); - } else if (storageType.isa()) { + } else if (mlir::isa(storageType)) { defaultMin = QuantizedType::getDefaultMinimumForF8E4M3FN(); defaultMax = QuantizedType::getDefaultMaximumForF8E4M3FN(); + } else if (mlir::isa(storageType)) { + defaultMin = QuantizedType::getDefaultMinimumForF4E2M1FN(); + defaultMax = QuantizedType::getDefaultMaximumForF4E2M1FN(); } else { defaultMin = std::numeric_limits::max(); defaultMax = std::numeric_limits::min(); @@ -150,7 +152,7 @@ static ParseResult parseStorageRange(DialectAsmParser &parser, Type storageType, } // Explicit storage min and storage max. - // F8 min and max values are integers, so parseInteger() is used. + // F8 and F4 min and max values are integers, so parseInteger() is used. SMLoc minLoc = parser.getCurrentLocation(), maxLoc; if (parser.parseInteger(storageTypeMin) || parser.parseColon() || parser.getCurrentLocation(&maxLoc) || @@ -382,7 +384,7 @@ parseQuantParamListUntilRBrace(DialectAsmParser &parser, Type expressedType, /// block-size-info `,` scale-zero-tensor `>` /// storage-spec ::= storage-type (`<` storage-range `>`)? /// storage-range ::= integer-literal `:` integer-literal -/// storage-type ::= (`i` | `u`) integer-literal +/// storage-type ::= (`i` | `u`) integer-literal | `f8E5M2` | `f8E4M3FN` | `f4E2M1FN` /// expressed-type-spec ::= `:` `f` integer-literal /// axis-spec ::= `:` integer-literal /// scale-zero ::= scale (`:` zero-point)? @@ -407,9 +409,9 @@ parseQuantParamListUntilRBrace(DialectAsmParser &parser, Type expressedType, /// scale-zero-list `>` /// storage-spec ::= storage-type (`<` storage-range `>`)? /// storage-range ::= integer-literal `:` integer-literal -/// storage-type ::= (`i` | `u`) integer-literal +/// storage-type ::= (`i` | `u`) integer-literal | `f8E5M2` | `f8E4M3FN` | `f4E2M1FN` /// quantile-type-spec ::= `:` ((`i` | `u` | `f`) integer-literal | `f8E5M2` | -/// `f8E4M3FN`) +/// `f8E4M3FN` | `f4E2M1FN`) /// expressed-type-spec ::= `:` `f` integer-literal axis-spec ::= /// `:` integer-literal quantiles-list ::= `{` quantile (`,` quantile)* `}` /// scale-zero ::= `:` float-literal `:` integer-literal @@ -641,6 +643,8 @@ static void printStorageType(QuantizedType type, DialectAsmPrinter &out) { out << "f8E5M2"; } else if (type.getStorageType().isa()) { out << "f8E4M3FN"; + } else if (type.getStorageType().isa()) { + out << "f4E2M1FN"; } else if (isSigned) { out << "i" << storageWidth; } else { @@ -655,7 +659,9 @@ static void printStorageType(QuantizedType type, DialectAsmPrinter &out) { ? QuantizedType::getDefaultMinimumForF8E5M2() : type.getStorageType().isa() ? QuantizedType::getDefaultMinimumForF8E4M3FN() - : std::numeric_limits::max(); + : type.getStorageType().isa() + ? QuantizedType::getDefaultMinimumForF4E2M1FN() + : std::numeric_limits::max(); int64_t defaultMax = type.getStorageType().isa() @@ -664,7 +670,9 @@ static void printStorageType(QuantizedType type, DialectAsmPrinter &out) { ? QuantizedType::getDefaultMaximumForF8E5M2() : type.getStorageType().isa() ? QuantizedType::getDefaultMaximumForF8E4M3FN() - : std::numeric_limits::min(); + : type.getStorageType().isa() + ? QuantizedType::getDefaultMaximumForF4E2M1FN() + : std::numeric_limits::min(); if (defaultMin != type.getStorageTypeMin() || defaultMax != type.getStorageTypeMax()) { @@ -685,6 +693,8 @@ static void printQuantileType(Type quantileType, DialectAsmPrinter &out) { out << ":f8E5M2"; } else if (quantileType.isa()) { out << ":f8E4M3FN"; + } else if (quantileType.isa()) { + out << ":f4E2M1FN"; } else { // Float types out << ":" << quantileType; diff --git a/mlir/test/Dialect/Quant/parse-quantile-invalid.mlir b/mlir/test/Dialect/Quant/parse-quantile-invalid.mlir index 005faa60e3cb..6e1bacd7d4c4 100644 --- a/mlir/test/Dialect/Quant/parse-quantile-invalid.mlir +++ b/mlir/test/Dialect/Quant/parse-quantile-invalid.mlir @@ -126,6 +126,16 @@ func.func @parse() -> !qalias { // expected-error@+1 {{illegal storage type minimum: -500}} !qalias = !quant.quantile:f16:f32, {-1.0,1.0}:0.99872:127> +// ----- +// Illegal storage min/max: max > defaultMax +// expected-error@+1 {{illegal storage type maximum: 10}} +!qalias = !quant.quantile:f16:f32, {-1.0,1.0}:0.99872:127> + +// ----- +// Illegal storage min/max: min < defaultMin +// expected-error@+1 {{illegal storage type minimum: -10}} +!qalias = !quant.quantile:f16:f32, {-1.0,1.0}:0.99872:127> + // ----- // Illegal uniform params: invalid scale // expected-error@+1 {{expected floating point literal}} diff --git a/mlir/test/Dialect/Quant/parse-quantile.mlir b/mlir/test/Dialect/Quant/parse-quantile.mlir index 1af567478e71..bb20499eb74d 100644 --- a/mlir/test/Dialect/Quant/parse-quantile.mlir +++ b/mlir/test/Dialect/Quant/parse-quantile.mlir @@ -46,6 +46,15 @@ func.func @parse() -> !qalias { return %0 : !qalias } +// ----- +// Default min/max value optimization for f4E2M1FN. +// CHECK: !quant.quantile +!qalias = !quant.quantile:f16:f32, {-1.0000,-0.8667,-0.7333,-0.6000,-0.4667,-0.3333,-0.2000,-0.0667,0.0667,0.2000,0.3333,0.4667,0.6000,0.7333,0.8667,1.0000}:0.99872:127 > +func.func @parse() -> !qalias { + %0 = "foo"() : () -> !qalias + return %0 : !qalias +} + // ----- // Required per-layer params specified: // [unsigned] storageType, expressedType, scale @@ -92,6 +101,15 @@ func.func @parse() -> !qalias { return %0 : !qalias } +// ----- +// Storage type: f4E2M1FN +// CHECK: !quant.quantile +!qalias = !quant.quantile +func.func @parse() -> !qalias { + %0 = "foo"() : () -> !qalias + return %0 : !qalias +} + // ----- // Expressed type: f32 // CHECK: !quant.quantile diff --git a/mlir/test/Dialect/Quant/parse-uniform-invalid.mlir b/mlir/test/Dialect/Quant/parse-uniform-invalid.mlir index 5553aabe4599..c40273d57b1c 100644 --- a/mlir/test/Dialect/Quant/parse-uniform-invalid.mlir +++ b/mlir/test/Dialect/Quant/parse-uniform-invalid.mlir @@ -100,6 +100,16 @@ // expected-error@+1 {{illegal storage type minimum: -500}} !qalias = !quant.uniform:f32, 0.99872:127> +// ----- +// Illegal storage min/max: max > defaultMax +// expected-error@+1 {{illegal storage type maximum: 10}} +!qalias = !quant.uniform:f32, 0.99872:127> + +// ----- +// Illegal storage min/max: min < defaultMin +// expected-error@+1 {{illegal storage type minimum: -10}} +!qalias = !quant.uniform:f32, 0.99872:127> + // ----- // Illegal uniform params: invalid scale // expected-error@+1 {{expected floating point literal}} diff --git a/mlir/test/Dialect/Quant/parse-uniform.mlir b/mlir/test/Dialect/Quant/parse-uniform.mlir index 383830d7f1b1..e7cee25c6413 100644 --- a/mlir/test/Dialect/Quant/parse-uniform.mlir +++ b/mlir/test/Dialect/Quant/parse-uniform.mlir @@ -46,6 +46,15 @@ func.func @parse() -> !qalias { return %0 : !qalias } +// ----- +// Default min/max value optimization for f4E2M1FN. +// CHECK: !quant.uniform +!qalias = !quant.uniform:f32, 0.99872:127 > +func.func @parse() -> !qalias { + %0 = "foo"() : () -> !qalias + return %0 : !qalias +} + // ----- // Required per-layer params specified: // [unsigned] storageType, expressedType, scale @@ -92,6 +101,15 @@ func.func @parse() -> !qalias { return %0 : !qalias } +// ----- +// Storage type: f4E2M1FN +// CHECK: !quant.uniform +!qalias = !quant.uniform +func.func @parse() -> !qalias { + %0 = "foo"() : () -> !qalias + return %0 : !qalias +} + // ----- // Storage type: i16 // CHECK: !quant.uniform