diff --git a/jvm/src/main/scala/io/kaitai/struct/JavaMain.scala b/jvm/src/main/scala/io/kaitai/struct/JavaMain.scala index a8d896443..e53739b05 100644 --- a/jvm/src/main/scala/io/kaitai/struct/JavaMain.scala +++ b/jvm/src/main/scala/io/kaitai/struct/JavaMain.scala @@ -57,6 +57,10 @@ object JavaMain { } } + opt[Unit]('w', "read-write") action { (x, c) => + c.copy(runtime = c.runtime.copy(readWrite = true, autoRead = false)) + } text("generate read-write support in classes (implies `--no-auto-read --zero-copy-substream false`, Java and Python only, default: read-only)") + opt[File]('d', "outdir") valueName("<directory>") action { (x, c) => c.copy(outDir = x) } text("output directory (filenames will be auto-generated); on Unix-like shells, the short form `-d` requires arguments to be preceded by `--`") @@ -172,7 +176,13 @@ object JavaMain { version("version") text("output version information and exit") } - parser.parse(args, CLIConfig()) + parser.parse(args, CLIConfig()).map { c => + if (c.runtime.readWrite) { + c.copy(runtime = c.runtime.copy(zeroCopySubstream = false)) + } else { + c + } + } } /** diff --git a/jvm/src/test/scala/io/kaitai/struct/translators/TranslatorSpec.scala b/jvm/src/test/scala/io/kaitai/struct/translators/TranslatorSpec.scala index 522b6c226..88ea57455 100644 --- a/jvm/src/test/scala/io/kaitai/struct/translators/TranslatorSpec.scala +++ b/jvm/src/test/scala/io/kaitai/struct/translators/TranslatorSpec.scala @@ -697,7 +697,7 @@ class TranslatorSpec extends AnyFunSuite { CppCompiler -> new CppTranslator(tp, new CppImportList(), new CppImportList(), RuntimeConfig()), CSharpCompiler -> new CSharpTranslator(tp, new ImportList()), GoCompiler -> new GoTranslator(goOutput, tp, new ImportList()), - JavaCompiler -> new JavaTranslator(tp, new ImportList()), + JavaCompiler -> new JavaTranslator(tp, new ImportList(), RuntimeConfig()), JavaScriptCompiler -> new JavaScriptTranslator(tp), LuaCompiler -> new LuaTranslator(tp, new ImportList()), PerlCompiler -> new PerlTranslator(tp, new ImportList()), diff --git a/shared/src/main/scala/io/kaitai/struct/ClassCompiler.scala b/shared/src/main/scala/io/kaitai/struct/ClassCompiler.scala index f7aa87f56..90a9524bc 100644 --- a/shared/src/main/scala/io/kaitai/struct/ClassCompiler.scala +++ b/shared/src/main/scala/io/kaitai/struct/ClassCompiler.scala @@ -78,6 +78,15 @@ class ClassCompiler( // Read method(s) compileEagerRead(curClass.seq, curClass.meta.endian) + compileFetchInstancesProc(curClass.seq ++ curClass.instances.values.collect { + case inst: AttrLikeSpec => inst + }) + + if (config.readWrite) { + compileWrite(curClass.seq, curClass.instances, curClass.meta.endian) + compileCheck(curClass.seq) + } + // Destructor compileDestructor(curClass) @@ -132,6 +141,15 @@ class ClassCompiler( ) compileInit(curClass) curClass.instances.foreach { case (instName, _) => lang.instanceClear(instName) } + if (config.readWrite) { + curClass.instances.foreach { case (instName, instSpec) => + instSpec match { + case _: ParseInstanceSpec => + lang.instanceWriteFlagInit(instName) + case _: ValueInstanceSpec => // do nothing + } + } + } if (lang.config.autoRead) lang.runRead(curClass.name) lang.classConstructorFooter @@ -217,6 +235,9 @@ class ClassCompiler( attr.isNullable } lang.attributeReader(attr.id, attr.dataTypeComposite, isNullable) + if (config.readWrite) { + lang.attributeSetter(attr.id, attr.dataTypeComposite, isNullable) + } } } @@ -236,25 +257,45 @@ class ClassCompiler( def compileEagerRead(seq: List[AttrSpec], endian: Option[Endianness]): Unit = { endian match { case None | Some(_: FixedEndian) => - compileSeqProc(seq, None) + compileSeqReadProc(seq, None) case Some(ce: CalcEndian) => lang.readHeader(None, false) compileCalcEndian(ce) lang.runReadCalc() lang.readFooter() - compileSeqProc(seq, Some(LittleEndian)) - compileSeqProc(seq, Some(BigEndian)) + compileSeqReadProc(seq, Some(LittleEndian)) + compileSeqReadProc(seq, Some(BigEndian)) case Some(InheritedEndian) => lang.readHeader(None, false) lang.runReadCalc() lang.readFooter() - compileSeqProc(seq, Some(LittleEndian)) - compileSeqProc(seq, Some(BigEndian)) + compileSeqReadProc(seq, Some(LittleEndian)) + compileSeqReadProc(seq, Some(BigEndian)) + } + } + + def compileWrite(seq: List[AttrSpec], instances: Map[InstanceIdentifier, InstanceSpec], endian: Option[Endianness]): Unit = { + endian match { + case None | Some(_: FixedEndian) => + compileSeqWriteProc(seq, instances, None) + case Some(CalcEndian(_, _)) | Some(InheritedEndian) => + lang.writeHeader(None, false) + lang.runWriteCalc() + lang.writeFooter() + + compileSeqWriteProc(seq, instances, Some(LittleEndian)) + compileSeqWriteProc(seq, instances, Some(BigEndian)) } } + def compileCheck(seq: List[AttrSpec]): Unit = { + lang.checkHeader() + compileSeqCheck(seq) + lang.checkFooter() + } + val IS_LE_ID = SpecialIdentifier("_is_le") /** @@ -276,18 +317,31 @@ class ClassCompiler( * @param seq sequence of attributes * @param defEndian default endianness */ - def compileSeqProc(seq: List[AttrSpec], defEndian: Option[FixedEndian]) = { + def compileSeqReadProc(seq: List[AttrSpec], defEndian: Option[FixedEndian]) = { lang.readHeader(defEndian, seq.isEmpty) - compileSeq(seq, defEndian) + compileSeqRead(seq, defEndian) lang.readFooter() } + def compileFetchInstancesProc(attrs: List[AttrLikeSpec]) = { + lang.fetchInstancesHeader() + compileFetchInstances(attrs) + lang.fetchInstancesFooter() + } + + def compileSeqWriteProc(seq: List[AttrSpec], instances: Map[InstanceIdentifier, InstanceSpec], defEndian: Option[FixedEndian]) = { + lang.writeHeader(defEndian, !instances.values.exists(i => i.isInstanceOf[ParseInstanceSpec]) && seq.isEmpty) + compileSetInstanceWriteFlags(instances) + compileSeqWrite(seq, defEndian) + lang.writeFooter() + } + /** * Compiles seq reading method body (only reading statements). * @param seq sequence of attributes * @param defEndian default endianness */ - def compileSeq(seq: List[AttrSpec], defEndian: Option[FixedEndian]) = { + def compileSeqRead(seq: List[AttrSpec], defEndian: Option[FixedEndian]) = { var wasUnaligned = false seq.foreach { (attr) => val nowUnaligned = isUnalignedBits(attr.dataType) @@ -298,6 +352,34 @@ class ClassCompiler( } } + def compileFetchInstances(attrs: List[AttrLikeSpec]): Unit = { + attrs.foreach { (attr) => + lang.attrFetchInstances(attr, attr.id) + } + } + + def compileSetInstanceWriteFlags(instances: Map[InstanceIdentifier, InstanceSpec]) = { + instances.foreach { case (instName, instSpec) => + instSpec match { + case _: ParseInstanceSpec => + lang.instanceSetWriteFlag(instName) + case _: ValueInstanceSpec => // do nothing + } + } + } + + def compileSeqWrite(seq: List[AttrSpec], defEndian: Option[FixedEndian]) = { + seq.foreach { (attr) => + lang.attrWrite(attr, attr.id, defEndian) + } + } + + def compileSeqCheck(seq: List[AttrSpec]) = { + seq.foreach { (attr) => + lang.attrCheck(attr, attr.id) + } + } + /** * Compiles all enums specifications for a given type. * @param curClass current type to generate code for @@ -329,6 +411,12 @@ class ClassCompiler( lang.instanceHeader(className, instName, dataType, instSpec.isNullable) if (lang.innerDocstrings) compileInstanceDoc(instName, instSpec) + if (config.readWrite) + instSpec match { + case _: ParseInstanceSpec => + lang.instanceCheckWriteFlagAndWrite(instName) + case _: ValueInstanceSpec => // do nothing + } lang.instanceCheckCacheAndReturn(instName, dataType) instSpec match { @@ -343,6 +431,22 @@ class ClassCompiler( lang.instanceReturn(instName, dataType) lang.instanceFooter + + if (config.readWrite) + instSpec match { + case pi: ParseInstanceSpec => + lang.attributeSetter(instName, dataType, instSpec.isNullable) + lang.instanceToWriteSetter(instName) + lang.writeInstanceHeader(instName) + lang.attrWrite(pi, instName, endian) + lang.writeInstanceFooter + + lang.checkInstanceHeader(instName) + lang.attrCheck(pi, instName) + lang.checkInstanceFooter + case _: ValueInstanceSpec => + lang.instanceInvalidate(instName) + } } def compileInstanceDeclaration(instName: InstanceIdentifier, instSpec: InstanceSpec): Unit = { @@ -355,6 +459,12 @@ class ClassCompiler( instSpec.isNullable } lang.instanceDeclaration(instName, instSpec.dataTypeComposite, isNullable) + if (config.readWrite) + instSpec match { + case _: ParseInstanceSpec => + lang.instanceWriteFlagDeclaration(instName) + case _: ValueInstanceSpec => // do nothing + } } def compileEnum(curClass: ClassSpec, enumColl: EnumSpec): Unit = diff --git a/shared/src/main/scala/io/kaitai/struct/ClassTypeProvider.scala b/shared/src/main/scala/io/kaitai/struct/ClassTypeProvider.scala index abed472ff..c376c985e 100644 --- a/shared/src/main/scala/io/kaitai/struct/ClassTypeProvider.scala +++ b/shared/src/main/scala/io/kaitai/struct/ClassTypeProvider.scala @@ -57,6 +57,33 @@ class ClassTypeProvider(classSpecs: ClassSpecs, var topClass: ClassSpec) extends case NamedIdentifier(name) => return determineType(inClass, name) case InstanceIdentifier(name) => return determineType(inClass, name) case SpecialIdentifier(name) => return determineType(inClass, name) + case RawIdentifier(innerId) => { + val innerType = determineType(innerId) + val (isArray, itemType: DataType) = innerType match { + case at: ArrayType => (true, at.elType) + case t => (false, t) + } + val singleType: DataType = itemType match { + case st: SwitchType => st.cases.collectFirst { + case (_, caseType) + if caseType.isInstanceOf[BytesType] + || caseType.isInstanceOf[UserTypeFromBytes] => caseType + }.get + case t => t + } + /** see [[languages.components.ExtraAttrs$]] for possible types */ + val bytesType = singleType match { + case bt: BytesType => bt + case utb: UserTypeFromBytes => utb.bytes + } + return if (isArray) ArrayTypeInStream(bytesType) else bytesType + } + case OuterSizeIdentifier(innerId) => + val singleType = CalcIntType + return if (determineType(innerId).isInstanceOf[ArrayType]) ArrayTypeInStream(singleType) else singleType + case InnerSizeIdentifier(innerId) => + val singleType = CalcIntType + return if (determineType(innerId).isInstanceOf[ArrayType]) ArrayTypeInStream(singleType) else singleType case _ => // do nothing } throw new FieldNotFoundError(attrId.humanReadable, inClass) diff --git a/shared/src/main/scala/io/kaitai/struct/DocClassCompiler.scala b/shared/src/main/scala/io/kaitai/struct/DocClassCompiler.scala index 685cd1c38..339a6c20b 100644 --- a/shared/src/main/scala/io/kaitai/struct/DocClassCompiler.scala +++ b/shared/src/main/scala/io/kaitai/struct/DocClassCompiler.scala @@ -37,7 +37,7 @@ abstract class DocClassCompiler(classSpecs: ClassSpecs, topClass: ClassSpec) ext classHeader(curClass) // Sequence - compileSeq(curClass) + compileSeqRead(curClass) // Instances curClass.instances.foreach { case (_, instSpec) => @@ -58,7 +58,7 @@ abstract class DocClassCompiler(classSpecs: ClassSpecs, topClass: ClassSpec) ext classFooter(curClass) } - def compileSeq(curClass: ClassSpec): Unit = { + def compileSeqRead(curClass: ClassSpec): Unit = { seqHeader(curClass) CalculateSeqSizes.forEachSeqAttr(curClass, (attr, seqPos, sizeElement, sizeContainer) => { diff --git a/shared/src/main/scala/io/kaitai/struct/GraphvizClassCompiler.scala b/shared/src/main/scala/io/kaitai/struct/GraphvizClassCompiler.scala index b41a45023..0a59a83fa 100644 --- a/shared/src/main/scala/io/kaitai/struct/GraphvizClassCompiler.scala +++ b/shared/src/main/scala/io/kaitai/struct/GraphvizClassCompiler.scala @@ -59,7 +59,7 @@ class GraphvizClassCompiler(classSpecs: ClassSpecs, topClass: ClassSpec) extends out.puts // Sequence - compileSeq(className, curClass) + compileSeqRead(className, curClass) curClass.instances.foreach { case (instName, instSpec) => instSpec match { @@ -84,7 +84,7 @@ class GraphvizClassCompiler(classSpecs: ClassSpecs, topClass: ClassSpec) extends out.puts("}") } - def compileSeq(className: List[String], curClass: ClassSpec): Unit = { + def compileSeqRead(className: List[String], curClass: ClassSpec): Unit = { tableStart(className, "seq") CalculateSeqSizes.forEachSeqAttr(curClass, (attr, seqPos, _, _) => { diff --git a/shared/src/main/scala/io/kaitai/struct/NimClassCompiler.scala b/shared/src/main/scala/io/kaitai/struct/NimClassCompiler.scala index a3f308dfc..b9d9091fb 100644 --- a/shared/src/main/scala/io/kaitai/struct/NimClassCompiler.scala +++ b/shared/src/main/scala/io/kaitai/struct/NimClassCompiler.scala @@ -51,17 +51,17 @@ class NimClassCompiler( override def compileEagerRead(seq: List[AttrSpec], endian: Option[Endianness]): Unit = { endian match { case None | Some(_: FixedEndian) => - compileSeqProc(seq, None) + compileSeqReadProc(seq, None) case Some(ce: CalcEndian) => - compileSeqProc(seq, Some(LittleEndian)) - compileSeqProc(seq, Some(BigEndian)) + compileSeqReadProc(seq, Some(LittleEndian)) + compileSeqReadProc(seq, Some(BigEndian)) lang.readHeader(None, false) compileCalcEndian(ce) lang.runReadCalc() lang.readFooter() case Some(InheritedEndian) => - compileSeqProc(seq, Some(LittleEndian)) - compileSeqProc(seq, Some(BigEndian)) + compileSeqReadProc(seq, Some(LittleEndian)) + compileSeqReadProc(seq, Some(BigEndian)) lang.readHeader(None, false) lang.runReadCalc() lang.readFooter() @@ -69,7 +69,7 @@ class NimClassCompiler( } // Must override just to add attribute docstrings - override def compileSeq(seq: List[AttrSpec], defEndian: Option[FixedEndian]) = { + override def compileSeqRead(seq: List[AttrSpec], defEndian: Option[FixedEndian]) = { var wasUnaligned = false seq.foreach { (attr) => val nowUnaligned = isUnalignedBits(attr.dataType) @@ -153,7 +153,7 @@ class NimClassCompiler( def compileTypesRec(curClass: ClassSpec): Unit = { curClass.types.foreach { case (_, subClass) => compileTypes(subClass) } } - + // def compileEnumConstants(curClass: ClassSpec): Unit = { // provider.nowClass = curClass // curClass.enums.foreach { case(_, enumColl) => { @@ -203,4 +203,3 @@ class NimClassCompiler( } } - diff --git a/shared/src/main/scala/io/kaitai/struct/RuntimeConfig.scala b/shared/src/main/scala/io/kaitai/struct/RuntimeConfig.scala index e31d44b3a..653692b50 100644 --- a/shared/src/main/scala/io/kaitai/struct/RuntimeConfig.scala +++ b/shared/src/main/scala/io/kaitai/struct/RuntimeConfig.scala @@ -99,6 +99,7 @@ case class RuntimeConfig( readStoresPos: Boolean = false, opaqueTypes: Boolean = false, zeroCopySubstream: Boolean = true, + readWrite: Boolean = false, cppConfig: CppRuntimeConfig = CppRuntimeConfig(), goPackage: String = "", java: JavaRuntimeConfig = JavaRuntimeConfig(), diff --git a/shared/src/main/scala/io/kaitai/struct/RustClassCompiler.scala b/shared/src/main/scala/io/kaitai/struct/RustClassCompiler.scala index a148abc43..8deaedc7f 100644 --- a/shared/src/main/scala/io/kaitai/struct/RustClassCompiler.scala +++ b/shared/src/main/scala/io/kaitai/struct/RustClassCompiler.scala @@ -66,7 +66,7 @@ class RustClassCompiler( lang.readHeader(defEndian, false) - compileSeq(curClass.seq, defEndian) + compileSeqRead(curClass.seq, defEndian) lang.classConstructorFooter } diff --git a/shared/src/main/scala/io/kaitai/struct/format/Identifier.scala b/shared/src/main/scala/io/kaitai/struct/format/Identifier.scala index 02b89f9c7..a3f02d8a4 100644 --- a/shared/src/main/scala/io/kaitai/struct/format/Identifier.scala +++ b/shared/src/main/scala/io/kaitai/struct/format/Identifier.scala @@ -20,7 +20,7 @@ abstract class Identifier { */ case class NumberedIdentifier(idx: Int) extends Identifier { import NumberedIdentifier._ - override def humanReadable: String = s"${TEMPLATE}_$idx" + override def humanReadable: String = s"_${TEMPLATE}$idx" } object NumberedIdentifier { @@ -90,6 +90,14 @@ case class IoStorageIdentifier(innerId: Identifier) extends Identifier { override def humanReadable: String = s"io(${innerId.humanReadable})" } +case class OuterSizeIdentifier(innerId: Identifier) extends Identifier { + override def humanReadable: String = s"outerSize(${innerId.humanReadable})" +} + +case class InnerSizeIdentifier(innerId: Identifier) extends Identifier { + override def humanReadable: String = s"innerSize(${innerId.humanReadable})" +} + case class InstanceIdentifier(name: String) extends Identifier { Identifier.checkIdentifier(name) diff --git a/shared/src/main/scala/io/kaitai/struct/languages/CSharpCompiler.scala b/shared/src/main/scala/io/kaitai/struct/languages/CSharpCompiler.scala index 991b0e7d4..b2731e557 100644 --- a/shared/src/main/scala/io/kaitai/struct/languages/CSharpCompiler.scala +++ b/shared/src/main/scala/io/kaitai/struct/languages/CSharpCompiler.scala @@ -17,7 +17,6 @@ class CSharpCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig) with AllocateIOLocalVar with EveryReadIsExpression with UniversalDoc - with FixedContentsUsingArrayByteLiteral with SwitchIfOps with NoNeedForFullClassPath { import CSharpCompiler._ @@ -193,9 +192,6 @@ class CSharpCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig) out.puts("}") } - override def attrFixedContentsParse(attrName: Identifier, contents: String): Unit = - out.puts(s"${privateMemberName(attrName)} = $normalIO.EnsureFixedContents($contents);") - override def attrProcess(proc: ProcessExpr, varSrc: Identifier, varDest: Identifier, rep: RepeatSpec): Unit = { val srcExpr = getRawIdExpr(varSrc, rep) @@ -272,7 +268,7 @@ class CSharpCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig) out.inc } - override def condIfFooter(expr: expr): Unit = fileFooter(null) + override def condIfFooter: Unit = fileFooter(null) override def condRepeatInitAttr(id: Identifier, dataType: DataType): Unit = { importList.add("System.Collections.Generic") @@ -580,17 +576,21 @@ class CSharpCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig) override def ksErrorName(err: KSError): String = CSharpCompiler.ksErrorName(err) override def attrValidateExpr( - attrId: Identifier, - attrType: DataType, + attr: AttrLikeSpec, checkExpr: Ast.expr, err: KSError, - errArgs: List[Ast.expr] + useIo: Boolean, + expected: Option[Ast.expr] = None ): Unit = { - val errArgsStr = errArgs.map(translator.translate).mkString(", ") + val errArgsStr = expected.map(expression) ++ List( + expression(Ast.expr.InternalName(attr.id)), + if (useIo) expression(Ast.expr.InternalName(IoIdentifier)) else "null", + expression(Ast.expr.Str(attr.path.mkString("/", "/", ""))) + ) out.puts(s"if (!(${translator.translate(checkExpr)}))") out.puts("{") out.inc - out.puts(s"throw new ${ksErrorName(err)}($errArgsStr);") + out.puts(s"throw new ${ksErrorName(err)}(${errArgsStr.mkString(", ")});") out.dec out.puts("}") } diff --git a/shared/src/main/scala/io/kaitai/struct/languages/CppCompiler.scala b/shared/src/main/scala/io/kaitai/struct/languages/CppCompiler.scala index afef98f43..b6d93d31e 100644 --- a/shared/src/main/scala/io/kaitai/struct/languages/CppCompiler.scala +++ b/shared/src/main/scala/io/kaitai/struct/languages/CppCompiler.scala @@ -16,7 +16,6 @@ class CppCompiler( ) extends LanguageCompiler(typeProvider, config) with ObjectOrientedLanguage with AllocateAndStoreIO - with FixedContentsUsingArrayByteLiteral with UniversalDoc with SwitchIfOps with EveryReadIsExpression { @@ -452,9 +451,6 @@ class CppCompiler( outSrc.puts("}") } - override def attrFixedContentsParse(attrName: Identifier, contents: String): Unit = - outSrc.puts(s"${privateMemberName(attrName)} = $normalIO->ensure_fixed_contents($contents);") - override def attrProcess(proc: ProcessExpr, varSrc: Identifier, varDest: Identifier, rep: RepeatSpec): Unit = { val srcExpr = getRawIdExpr(varSrc, rep) @@ -570,7 +566,7 @@ class CppCompiler( outSrc.inc } - override def condIfFooter(expr: Ast.expr): Unit = { + override def condIfFooter: Unit = { outSrc.dec outSrc.puts("}") } @@ -1018,17 +1014,21 @@ class CppCompiler( } override def attrValidateExpr( - attrId: Identifier, - attrType: DataType, + attr: AttrLikeSpec, checkExpr: Ast.expr, err: KSError, - errArgs: List[Ast.expr] + useIo: Boolean, + expected: Option[Ast.expr] = None ): Unit = { - val errArgsStr = errArgs.map(translator.translate).mkString(", ") + val errArgsStr = expected.map(expression) ++ List( + expression(Ast.expr.InternalName(attr.id)), + if (useIo) expression(Ast.expr.InternalName(IoIdentifier)) else nullPtr, + expression(Ast.expr.Str(attr.path.mkString("/", "/", ""))) + ) importListSrc.addKaitai("kaitai/exceptions.h") outSrc.puts(s"if (!(${translator.translate(checkExpr)})) {") outSrc.inc - outSrc.puts(s"throw ${ksErrorName(err)}($errArgsStr);") + outSrc.puts(s"throw ${ksErrorName(err)}(${errArgsStr.mkString(", ")});") outSrc.dec outSrc.puts("}") } diff --git a/shared/src/main/scala/io/kaitai/struct/languages/GoCompiler.scala b/shared/src/main/scala/io/kaitai/struct/languages/GoCompiler.scala index 4afa8dae4..90b4d1637 100644 --- a/shared/src/main/scala/io/kaitai/struct/languages/GoCompiler.scala +++ b/shared/src/main/scala/io/kaitai/struct/languages/GoCompiler.scala @@ -190,25 +190,6 @@ class GoCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig) out.puts("}") } - override def attrFixedContentsParse(attrName: Identifier, contents: Array[Byte]): Unit = { - out.puts(s"${privateMemberName(attrName)}, err = $normalIO.ReadBytes(${contents.length})") - - out.puts(s"if err != nil {") - out.inc - out.puts("return err") - out.dec - out.puts("}") - - importList.add("bytes") - importList.add("errors") - val expected = translator.resToStr(translator.doByteArrayLiteral(contents)) - out.puts(s"if !bytes.Equal(${privateMemberName(attrName)}, $expected) {") - out.inc - out.puts("return errors.New(\"Unexpected fixed contents\")") - out.dec - out.puts("}") - } - override def attrProcess(proc: ProcessExpr, varSrc: Identifier, varDest: Identifier, rep: RepeatSpec): Unit = { val srcExpr = getRawIdExpr(varSrc, rep) @@ -550,16 +531,20 @@ class GoCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig) override def ksErrorName(err: KSError): String = GoCompiler.ksErrorName(err) override def attrValidateExpr( - attrId: Identifier, - attrType: DataType, + attr: AttrLikeSpec, checkExpr: Ast.expr, err: KSError, - errArgs: List[Ast.expr] + useIo: Boolean, + expected: Option[Ast.expr] = None ): Unit = { - val errArgsStr = errArgs.map(translator.translate).mkString(", ") + val errArgsStr = expected.map(expression) ++ List( + expression(Ast.expr.InternalName(attr.id)), + if (useIo) expression(Ast.expr.InternalName(IoIdentifier)) else "nil", + expression(Ast.expr.Str(attr.path.mkString("/", "/", ""))) + ) out.puts(s"if !(${translator.translate(checkExpr)}) {") out.inc - val errInst = s"kaitai.New${err.name}($errArgsStr)" + val errInst = s"kaitai.New${err.name}(${errArgsStr.mkString(", ")})" val noValueAndErr = translator.returnRes match { case None => errInst case Some(r) => s"$r, $errInst" diff --git a/shared/src/main/scala/io/kaitai/struct/languages/JavaCompiler.scala b/shared/src/main/scala/io/kaitai/struct/languages/JavaCompiler.scala index 231b4aeab..9e1780d1e 100644 --- a/shared/src/main/scala/io/kaitai/struct/languages/JavaCompiler.scala +++ b/shared/src/main/scala/io/kaitai/struct/languages/JavaCompiler.scala @@ -15,15 +15,17 @@ class JavaCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig) with UpperCamelCaseClasses with ObjectOrientedLanguage with EveryReadIsExpression + with FetchInstances + with EveryWriteIsExpression + with GenericChecks with UniversalFooter with UniversalDoc with AllocateIOLocalVar - with FixedContentsUsingArrayByteLiteral with SwitchIfOps with NoNeedForFullClassPath { import JavaCompiler._ - val translator = new JavaTranslator(typeProvider, importList) + val translator = new JavaTranslator(typeProvider, importList, config) // Preprocess fromFileClass and make import val fromFileClass = { @@ -38,6 +40,10 @@ class JavaCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig) } } + /** See [[subIOWriteBackHeader]] => the code generated when `true` will be inside the definition + * of the "writeBackHandler" callback function. */ + private var inSubIOWriteBackHandler = false + override def universalFooter: Unit = { out.dec out.puts("}") @@ -72,7 +78,10 @@ class JavaCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig) "" } - out.puts(s"public ${staticStr}class ${type2class(name)} extends $kstructName {") + out.puts( + s"public ${staticStr}class ${type2class(name)} " + + s"extends $kstructNameFull {" + ) out.inc if (config.readStoresPos) { @@ -131,6 +140,14 @@ class JavaCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig) val paramsRelay = Utils.join(params.map((p) => paramName(p.id)), ", ", ", ", "") + if (config.readWrite) { + out.puts(s"public ${type2class(name)}(${paramsArg.stripPrefix(", ")}) {") + out.inc + out.puts(s"this(null, null, null$paramsRelay);") + out.dec + out.puts("}") + } + out.puts out.puts(s"public ${type2class(name)}($kstreamName _io$paramsArg) {") out.inc @@ -181,28 +198,100 @@ class JavaCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig) out.puts("}") } + override def runWriteCalc(): Unit = { + out.puts + out.puts("if (_is_le == null) {") + out.inc + out.puts(s"throw new $kstreamName.UndecidedEndiannessError();") + out.dec + out.puts("} else if (_is_le) {") + out.inc + out.puts("_write_SeqLE();") + out.dec + out.puts("} else {") + out.inc + out.puts("_write_SeqBE();") + out.dec + out.puts("}") + } + override def readHeader(endian: Option[FixedEndian], isEmpty: Boolean) = { - val readAccessAndType = if (!config.autoRead) { - "public" - } else { - "private" + endian match { + case Some(e) => + out.puts(s"private void _read${Utils.upperUnderscoreCase(e.toSuffix)}() {") + case None => + out.puts(s"${if (!config.autoRead) "public" else "private"} void _read() {") } - val suffix = endian match { - case Some(e) => Utils.upperUnderscoreCase(e.toSuffix) - case None => "" + out.inc + } + + override def fetchInstancesHeader(): Unit = { + out.puts + out.puts("public void _fetchInstances() {") + out.inc + } + + override def fetchInstancesFooter: Unit = universalFooter + + override def attrInvokeFetchInstances(baseExpr: Ast.expr, exprType: DataType, dataType: DataType): Unit = { + val expr = castIfNeeded(expression(baseExpr), exprType, dataType) + out.puts(s"$expr._fetchInstances();") + } + + override def attrInvokeInstance(instName: InstanceIdentifier): Unit = { + out.puts(s"${publicMemberName(instName)}();") + } + + override def writeHeader(endian: Option[FixedEndian], isEmpty: Boolean): Unit = { + out.puts + endian match { + case Some(e) => + out.puts(s"private void _write_Seq${Utils.upperUnderscoreCase(e.toSuffix)}() {") + case None => + out.puts("public void _write_Seq() {") } - out.puts(s"$readAccessAndType void _read$suffix() {") out.inc } - override def readFooter(): Unit = universalFooter + override def checkHeader(): Unit = { + out.puts + out.puts("public void _check() {") + out.inc + } + + override def writeInstanceHeader(instName: InstanceIdentifier): Unit = { + out.puts + out.puts(s"public void _write${idToSetterStr(instName)}() {") + out.inc + instanceClearWriteFlag(instName) + } + + override def checkInstanceHeader(instName: InstanceIdentifier): Unit = { + out.puts + out.puts(s"public void _check${idToSetterStr(instName)}() {") + out.inc + } override def attributeDeclaration(attrName: Identifier, attrType: DataType, isNullable: Boolean): Unit = { out.puts(s"private ${kaitaiType2JavaType(attrType, isNullable)} ${idToStr(attrName)};") } override def attributeReader(attrName: Identifier, attrType: DataType, isNullable: Boolean): Unit = { - out.puts(s"public ${kaitaiType2JavaType(attrType, isNullable)} ${idToStr(attrName)}() { return ${idToStr(attrName)}; }") + val javaType = kaitaiType2JavaType(attrType, isNullable) + val name = idToStr(attrName) + + out.puts(s"public $javaType $name() { return $name; }") + } + + override def attributeSetter(attrName: Identifier, attrType: DataType, isNullable: Boolean): Unit = { + val javaType = kaitaiType2JavaType(attrType, isNullable) + val name = idToStr(attrName) + + out.puts(s"public void set${idToSetterStr(attrName)}($javaType _v) { $name = _v; }") + } + + override def attrSetProperty(base: Ast.expr, propName: Identifier, value: String): Unit = { + out.puts(s"${expression(base)}.set${idToSetterStr(propName)}($value);") } override def universalDoc(doc: DocSpec): Unit = { @@ -233,10 +322,6 @@ class JavaCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig) out.puts("}") } - override def attrFixedContentsParse(attrName: Identifier, contents: String): Unit = { - out.puts(s"${privateMemberName(attrName)} = $normalIO.ensureFixedContents($contents);") - } - override def attrProcess(proc: ProcessExpr, varSrc: Identifier, varDest: Identifier, rep: RepeatSpec): Unit = { val srcExpr = getRawIdExpr(varSrc, rep) @@ -268,6 +353,67 @@ class JavaCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig) handleAssignment(varDest, expr, rep, false) } + override def attrUnprocess(proc: ProcessExpr, varSrc: Identifier, varDest: Identifier, rep: RepeatSpec, dataType: BytesType, exprTypeOpt: Option[DataType]): Unit = { + val exprType = exprTypeOpt.getOrElse(dataType) + val srcExprRaw = varSrc match { + // use `_raw_items[_raw_items.size - 1]` + case _: RawIdentifier => getRawIdExpr(varSrc, rep) + // but `items[_index]` + case _ => expression(itemExpr(varSrc, rep)) + } + val srcExpr = castIfNeeded(srcExprRaw, exprType, dataType) + + val expr = proc match { + case ProcessXor(xorValue) => + val argStr = if (inSubIOWriteBackHandler) "_processXorArg" else expression(xorValue) + val xorValueStr = translator.detectType(xorValue) match { + case _: IntType => castIfNeeded(argStr, AnyType, Int1Type(true)) + case _ => argStr + } + s"$kstreamName.processXor($srcExpr, $xorValueStr)" + case ProcessZlib => + s"$kstreamName.unprocessZlib($srcExpr)" + case ProcessRotate(isLeft, rotValue) => + val argStr = if (inSubIOWriteBackHandler) "_processRotateArg" else expression(rotValue) + val expr = if (!isLeft) { + argStr + } else { + s"8 - ($argStr)" + } + s"$kstreamName.processRotateLeft($srcExpr, $expr, 1)" + case ProcessCustom(name, args) => + val namespace = name.init.mkString(".") + val procClass = namespace + + (if (namespace.nonEmpty) "." else "") + + type2class(name.last) + val procName = s"_process_${idToStr(varSrc)}" + if (!inSubIOWriteBackHandler) { + out.puts(s"$procClass $procName = new $procClass(${args.map(expression).mkString(", ")});") + } + s"$procName.encode($srcExpr)" + } + handleAssignment(varDest, expr, rep, false) + } + + override def attrUnprocessPrepareBeforeSubIOHandler(proc: ProcessExpr, varSrc: Identifier): Unit = { + proc match { + case ProcessXor(xorValue) => + val dataType = translator.detectType(xorValue) + out.puts(s"final ${kaitaiType2JavaType(dataType)} _processXorArg = ${expression(xorValue)};") + case ProcessRotate(_, rotValue) => + val dataType = translator.detectType(rotValue) + out.puts(s"final ${kaitaiType2JavaType(dataType)} _processRotateArg = ${expression(rotValue)};") + case ProcessZlib => // no process arguments + case ProcessCustom(name, args) => + val namespace = name.init.mkString(".") + val procClass = namespace + + (if (namespace.nonEmpty) "." else "") + + type2class(name.last) + val procName = s"_process_${idToStr(varSrc)}" + out.puts(s"final $procClass $procName = new $procClass(${args.map(expression).mkString(", ")});") + } + } + override def allocateIO(varName: Identifier, rep: RepeatSpec): String = { val ioName = idToStr(IoStorageIdentifier(varName)) @@ -281,8 +427,47 @@ class JavaCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig) ioName } + override def allocateIOFixed(varName: Identifier, size: String): String = { + val ioName = idToStr(IoStorageIdentifier(varName)) + + out.puts(s"final $kstreamName $ioName = new ByteBufferKaitaiStream($size);") + ioName + } + + override def exprIORemainingSize(io: String): String = + s"$io.size() - $io.pos()" + + override def allocateIOGrowing(varName: Identifier): String = + allocateIOFixed(varName, "100000") // FIXME to use real growing buffer + + override def subIOWriteBackHeader(subIO: String, process: Option[ProcessExpr]): String = { + val parentIoName = "parent" + out.puts(s"final ${type2class(typeProvider.nowClass.name.last)} _this = this;") + out.puts(s"$subIO.setWriteBackHandler(new $kstreamName.WriteBackHandler(_pos2) {") + out.inc + out.puts("@Override") + out.puts(s"protected void write($kstreamName $parentIoName) {") + out.inc + + inSubIOWriteBackHandler = true + + parentIoName + } + + override def subIOWriteBackFooter(subIO: String): Unit = { + inSubIOWriteBackHandler = false + + out.dec + out.puts("}") + out.dec + out.puts("});") + } + + override def addChildIO(io: String, childIO: String): Unit = + out.puts(s"$io.addChildStream($childIO);") + def getRawIdExpr(varName: Identifier, rep: RepeatSpec): String = { - val memberName = idToStr(varName) + val memberName = privateMemberName(varName) rep match { case NoRepeat => memberName case _ => s"$memberName.get($memberName.size() - 1)" @@ -297,14 +482,22 @@ class JavaCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig) override def pushPos(io: String): Unit = out.puts(s"long _pos = $io.pos();") + override def pushPosForSubIOWriteBackHandler(io: String): Unit = + out.puts(s"long _pos2 = $io.pos();") + override def seek(io: String, pos: Ast.expr): Unit = out.puts(s"$io.seek(${expression(pos)});") + override def seekRelative(io: String, relPos: String): Unit = + out.puts(s"$io.seek($io.pos() + ($relPos));") + override def popPos(io: String): Unit = out.puts(s"$io.seek(_pos);") - override def alignToByte(io: String): Unit = - out.puts(s"$io.alignToByte();") + // NOTE: the compiler does not need to output alignToByte() calls for Java anymore, + // since the byte alignment is handled by the runtime library since commit + // https://github.com/kaitai-io/kaitai_struct_java_runtime/commit/1bc75aa91199588a1cb12a5a1c672b80b66619ac + override def alignToByte(io: String): Unit = {} override def attrDebugStart(attrId: Identifier, attrType: DataType, ios: Option[String], rep: RepeatSpec): Unit = { ios.foreach { (io) => @@ -384,13 +577,19 @@ class JavaCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig) importList.add("java.util.ArrayList") } + // used for all repetitions in _check() + override def condRepeatCommonHeader(id: Identifier, io: String, dataType: DataType): Unit = { + out.puts(s"for (int i = 0; i < ${privateMemberName(id)}.size(); i++) {") + out.inc + } + override def handleAssignmentRepeatExpr(id: Identifier, expr: String): Unit = handleAssignmentRepeatEos(id, expr) override def condRepeatUntilHeader(id: Identifier, io: String, dataType: DataType, untilExpr: expr): Unit = { out.puts("{") out.inc - out.puts(s"${kaitaiType2JavaType(dataType)} ${translator.doName("_")};") + out.puts(s"${kaitaiType2JavaType(dataType)} ${translator.doName(Identifier.ITERATOR)};") out.puts("int i = 0;") out.puts("do {") out.inc @@ -462,11 +661,7 @@ class JavaCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig) s"new ${types2class(t.name)}($io$addArgs$addParams)" } - if (assignType != dataType) { - s"(${kaitaiType2JavaType(assignType)}) ($expr)" - } else { - expr - } + castIfNeeded(expr, dataType, assignType) } override def createSubstreamFixedSize(id: Identifier, blt: BytesLimitType, io: String, rep: RepeatSpec, defEndian: Option[FixedEndian]): String = { @@ -493,8 +688,9 @@ class JavaCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig) override def bytesPadTermExpr(expr0: String, padRight: Option[Int], terminator: Option[Int], include: Boolean) = { val expr1 = padRight match { - case Some(padByte) => s"$kstreamName.bytesStripRight($expr0, (byte) $padByte)" - case None => expr0 + case Some(padByte) if terminator.map(term => padByte != term).getOrElse(true) => + s"$kstreamName.bytesStripRight($expr0, (byte) $padByte)" + case _ => expr0 } val expr2 = terminator match { case Some(term) => s"$kstreamName.bytesTerminate($expr1, (byte) $term, $include)" @@ -504,11 +700,7 @@ class JavaCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig) } override def userTypeDebugRead(id: String, dataType: DataType, assignType: DataType): Unit = { - val expr = if (assignType != dataType) { - s"((${kaitaiType2JavaType(dataType)}) ($id))" - } else { - id - } + val expr = castIfNeeded(id, assignType, dataType) out.puts(s"$expr._read();") } @@ -653,6 +845,23 @@ class JavaCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig) out.puts(s"private ${kaitaiType2JavaTypeBoxed(attrType)} ${idToStr(attrName)};") } + override def instanceWriteFlagDeclaration(attrName: InstanceIdentifier): Unit = { + out.puts(s"private boolean _write${idToSetterStr(attrName)} = false;") + out.puts(s"private boolean _toWrite${idToSetterStr(attrName)} = true;") + } + + override def instanceSetWriteFlag(instName: InstanceIdentifier): Unit = { + out.puts(s"_write${idToSetterStr(instName)} = _toWrite${idToSetterStr(instName)};") + } + + override def instanceClearWriteFlag(instName: InstanceIdentifier): Unit = { + out.puts(s"_write${idToSetterStr(instName)} = false;") + } + + override def instanceToWriteSetter(instName: InstanceIdentifier): Unit = { + out.puts(s"public void set${idToSetterStr(instName)}_ToWrite(boolean _v) { _toWrite${idToSetterStr(instName)} = _v; }") + } + override def instanceHeader(className: String, instName: InstanceIdentifier, dataType: DataType, isNullable: Boolean): Unit = { out.puts(s"public ${kaitaiType2JavaTypeBoxed(dataType)} ${idToStr(instName)}() {") out.inc @@ -665,6 +874,13 @@ class JavaCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig) out.dec } + override def instanceCheckWriteFlagAndWrite(instName: InstanceIdentifier): Unit = { + out.puts(s"if (_write${idToSetterStr(instName)})") + out.inc + out.puts(s"_write${idToSetterStr(instName)}();") + out.dec + } + override def instanceReturn(instName: InstanceIdentifier, attrType: DataType): Unit = { out.puts(s"return ${privateMemberName(instName)};") } @@ -673,19 +889,22 @@ class JavaCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig) val primType = kaitaiType2JavaTypePrim(dataType) val boxedType = kaitaiType2JavaTypeBoxed(dataType) - if (primType != boxedType) { - // Special trick to achieve both implicit type conversion + boxing. - // Unfortunately, Java can't do both in one assignment, i.e. this would fail: + if (dataType.isInstanceOf[NumericType]) { + // Special trick to achieve both type conversion + boxing. + // Unfortunately, Java can't do both by itself, i.e. this would fail: // // Double c = 1.0f + 1; - out.puts(s"$primType _tmp = ($primType) (${expression(value)});") - out.puts(s"${privateMemberName(instName)} = _tmp;") + out.puts(s"${privateMemberName(instName)} = ${translator.doCast(value, dataType)};") } else { out.puts(s"${privateMemberName(instName)} = ${expression(value)};") } } + override def instanceInvalidate(instName: InstanceIdentifier): Unit = { + out.puts(s"public void _invalidate${idToSetterStr(instName)}() { ${privateMemberName(instName)} = null; }") + } + override def enumDeclaration(curClass: String, enumName: String, enumColl: Seq[(Long, String)]): Unit = { val enumClass = type2class(enumName) @@ -724,6 +943,13 @@ class JavaCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig) importList.add("java.util.HashMap") } + override def internalEnumIntType(basedOn: IntType): DataType = { + basedOn match { + case IntMultiType(signed, _, endian) => IntMultiType(signed, Width8, endian) + case _ => IntMultiType(true, Width8, None) + } + } + override def debugClassSequence(seq: List[AttrSpec]) = { val seqStr = seq.map((attr) => "\"" + idToStr(attr.id) + "\"").mkString(", ") out.puts(s"public static String[] _seqFields = new String[] { $seqStr };") @@ -739,13 +965,139 @@ class JavaCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig) out.puts("}") } + override def attrPrimitiveWrite( + io: String, + valueExpr: Ast.expr, + dataType: DataType, + defEndian: Option[FixedEndian], + exprTypeOpt: Option[DataType] + ): Unit = { + val exprType = exprTypeOpt.getOrElse(dataType) + val exprRaw = expression(valueExpr) + val expr = castIfNeeded(exprRaw, exprType, dataType) + + val stmt = dataType match { + case t: ReadableType => + s"$io.write${Utils.capitalize(t.apiCall(defEndian))}($expr)" + case BitsType1(bitEndian) => + s"$io.writeBitsInt${Utils.upperCamelCase(bitEndian.toSuffix)}(1, ${translator.boolToInt(valueExpr)})" + case BitsType(width: Int, bitEndian) => + s"$io.writeBitsInt${Utils.upperCamelCase(bitEndian.toSuffix)}($width, $expr)" + case _: BytesType => + s"$io.writeBytes($expr)" + } + out.puts(stmt + ";") + } + + override def attrBytesLimitWrite(io: String, expr: Ast.expr, size: String, term: Int, padRight: Int): Unit = + out.puts(s"$io.writeBytesLimit(${expression(expr)}, $size, (byte) $term, (byte) $padRight);") + + override def attrUserTypeInstreamWrite(io: String, valueExpr: Ast.expr, dataType: DataType, exprType: DataType) = { + val exprRaw = expression(valueExpr) + val expr = castIfNeeded(exprRaw, exprType, dataType) + out.puts(s"$expr._write_Seq($io);") + } + + override def exprStreamToByteArray(io: String): String = + s"$io.toByteArray()" + + override def attrBasicCheck(checkExpr: Ast.expr, actual: Ast.expr, expected: Ast.expr, msg: String): Unit = { + val msgStr = expression(Ast.expr.Str(msg)) + + out.puts(s"if (${expression(checkExpr)})") + out.inc + out.puts(s"throw new ConsistencyError($msgStr, ${expression(actual)}, ${expression(expected)});") + out.dec + + importList.add("io.kaitai.struct.ConsistencyError") + } + + override def attrObjectsEqualCheck(actual: Ast.expr, expected: Ast.expr, msg: String): Unit = { + val msgStr = expression(Ast.expr.Str(msg)) + + out.puts(s"if (!Objects.equals(${expression(actual)}, ${expression(expected)}))") + out.inc + out.puts(s"throw new ConsistencyError($msgStr, ${expression(actual)}, ${expression(expected)});") + out.dec + + importList.add("java.util.Objects") + importList.add("io.kaitai.struct.ConsistencyError") + } + + override def attrParentParamCheck(actualParentExpr: Ast.expr, ut: UserType, shouldDependOnIo: Option[Boolean], msg: String): Unit = { + if (ut.isOpaque) + return + /** @note Must be kept in sync with [[JavaCompiler.parseExpr]] */ + val (expectedParent, dependsOnIo) = ut.forcedParent match { + case Some(USER_TYPE_NO_PARENT) => ("null", false) + case Some(fp) => + (expression(fp), userExprDependsOnIo(fp)) + case None => ("this", false) + } + if (shouldDependOnIo.map(shouldDepend => dependsOnIo != shouldDepend).getOrElse(false)) + return + + val msgStr = expression(Ast.expr.Str(msg)) + + out.puts(s"if (!Objects.equals(${expression(actualParentExpr)}, $expectedParent))") + out.inc + out.puts(s"throw new ConsistencyError($msgStr, ${expression(actualParentExpr)}, $expectedParent);") + out.dec + + importList.add("java.util.Objects") + importList.add("io.kaitai.struct.ConsistencyError") + } + + override def attrIsEofCheck(io: String, expectedIsEof: Boolean, msg: String): Unit = { + val msgStr = expression(Ast.expr.Str(msg)) + + val eofExpr = s"$io.isEof()" + val ifExpr = if (expectedIsEof) { + s"!($eofExpr)" + } else { + eofExpr + } + out.puts(s"if ($ifExpr)") + out.inc + out.puts(s"throw new ConsistencyError($msgStr, ${exprIORemainingSize(io)}, 0);") + out.dec + + importList.add("io.kaitai.struct.ConsistencyError") + } + + override def condIfIsEofHeader(io: String, wantedIsEof: Boolean): Unit = { + val eofExpr = s"$io.isEof()" + val ifExpr = if (!wantedIsEof) { + s"!($eofExpr)" + } else { + eofExpr + } + + out.puts(s"if ($ifExpr) {") + out.inc + } + + override def condIfIsEofFooter: Unit = universalFooter + def value2Const(s: String) = Utils.upperUnderscoreCase(s) override def idToStr(id: Identifier): String = JavaCompiler.idToStr(id) override def publicMemberName(id: Identifier) = JavaCompiler.publicMemberName(id) - override def privateMemberName(id: Identifier): String = s"this.${idToStr(id)}" + def idToSetterStr(id: Identifier): String = { + id match { + case SpecialIdentifier(name) => name + case NamedIdentifier(name) => Utils.upperCamelCase(name) + case NumberedIdentifier(idx) => s"_${NumberedIdentifier.TEMPLATE}$idx" + case InstanceIdentifier(name) => Utils.upperCamelCase(name) + case RawIdentifier(innerId) => "_raw_" + idToSetterStr(innerId) + case OuterSizeIdentifier(innerId) => s"${idToSetterStr(innerId)}_OuterSize" + case InnerSizeIdentifier(innerId) => s"${idToSetterStr(innerId)}_InnerSize" + } + } + + override def privateMemberName(id: Identifier): String = s"${if (inSubIOWriteBackHandler) "_" else ""}this.${idToStr(id)}" override def localTemporaryName(id: Identifier): String = s"_t_${idToStr(id)}" @@ -756,40 +1108,31 @@ class JavaCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig) } override def attrValidateExpr( - attrId: Identifier, - attrType: DataType, + attr: AttrLikeSpec, checkExpr: Ast.expr, err: KSError, - errArgs: List[Ast.expr] + useIo: Boolean, + expected: Option[Ast.expr] = None ): Unit = { - val errArgsStr = errArgs.map(translator.translate).mkString(", ") + val errArgsStr = expected.map(expression) ++ List( + expression(Ast.expr.InternalName(attr.id)), + if (useIo) expression(Ast.expr.InternalName(IoIdentifier)) else "null", + expression(Ast.expr.Str(attr.path.mkString("/", "/", ""))) + ) out.puts(s"if (!(${translator.translate(checkExpr)})) {") out.inc - out.puts(s"throw new ${ksErrorName(err)}($errArgsStr);") + out.puts(s"throw new ${ksErrorName(err)}(${errArgsStr.mkString(", ")});") out.dec out.puts("}") } -} -object JavaCompiler extends LanguageCompilerStatic - with UpperCamelCaseClasses - with StreamStructNames { - override def getCompiler( - tp: ClassTypeProvider, - config: RuntimeConfig - ): LanguageCompiler = new JavaCompiler(tp, config) - - def idToStr(id: Identifier): String = - id match { - case SpecialIdentifier(name) => name - case NamedIdentifier(name) => Utils.lowerCamelCase(name) - case NumberedIdentifier(idx) => s"_${NumberedIdentifier.TEMPLATE}$idx" - case InstanceIdentifier(name) => Utils.lowerCamelCase(name) - case RawIdentifier(innerId) => s"_raw_${idToStr(innerId)}" - case IoStorageIdentifier(innerId) => s"_io_${idToStr(innerId)}" - } - - def publicMemberName(id: Identifier) = idToStr(id) + def kstructNameFull: String = { + kstructName + ((config.autoRead, config.readWrite) match { + case (_, true) => ".ReadWrite" + case (false, false) => ".ReadOnly" + case (true, false) => "" + }) + } def kaitaiType2JavaType(attrType: DataType): String = kaitaiType2JavaTypePrim(attrType) @@ -800,6 +1143,21 @@ object JavaCompiler extends LanguageCompilerStatic kaitaiType2JavaTypePrim(attrType) } + def castIfNeeded(exprRaw: String, exprType: DataType, targetType: DataType): String = + if (exprType != targetType) { + val castTypeId = kaitaiType2JavaTypePrim(targetType) + targetType match { + // Handles both unboxing + downcasting at the same time if needed + // (solution from https://github.com/kaitai-io/kaitai_struct_compiler/pull/149) + // + // See also https://github.com/kaitai-io/kaitai_struct_compiler/pull/212#issuecomment-731149487 + case _: NumericType => s"((Number) ($exprRaw)).${castTypeId}Value()" + case _ => s"(($castTypeId) ($exprRaw))" + } + } else { + exprRaw + } + /** * Determine Java data type corresponding to a KS data type. A "primitive" type (i.e. "int", "long", etc) will * be returned if possible. @@ -833,7 +1191,7 @@ object JavaCompiler extends LanguageCompilerStatic case AnyType => "Object" case KaitaiStreamType | OwnedKaitaiStreamType => kstreamName - case KaitaiStructType | CalcKaitaiStructType(_) => kstructName + case KaitaiStructType | CalcKaitaiStructType(_) => kstructNameFull case t: UserType => types2class(t.name) case EnumType(name, _) => types2class(name) @@ -877,7 +1235,7 @@ object JavaCompiler extends LanguageCompilerStatic case AnyType => "Object" case KaitaiStreamType | OwnedKaitaiStreamType => kstreamName - case KaitaiStructType | CalcKaitaiStructType(_) => kstructName + case KaitaiStructType | CalcKaitaiStructType(_) => kstructNameFull case t: UserType => types2class(t.name) case EnumType(name, _) => types2class(name) @@ -888,9 +1246,34 @@ object JavaCompiler extends LanguageCompilerStatic case st: SwitchType => kaitaiType2JavaTypeBoxed(st.combinedType) } } +} + +object JavaCompiler extends LanguageCompilerStatic + with UpperCamelCaseClasses + with StreamStructNames + with ExceptionNames { + override def getCompiler( + tp: ClassTypeProvider, + config: RuntimeConfig + ): LanguageCompiler = new JavaCompiler(tp, config) + + def idToStr(id: Identifier): String = + id match { + case SpecialIdentifier(name) => name + case NamedIdentifier(name) => Utils.lowerCamelCase(name) + case NumberedIdentifier(idx) => s"_${NumberedIdentifier.TEMPLATE}$idx" + case InstanceIdentifier(name) => Utils.lowerCamelCase(name) + case RawIdentifier(innerId) => s"_raw_${idToStr(innerId)}" + case IoStorageIdentifier(innerId) => s"_io_${idToStr(innerId)}" + case OuterSizeIdentifier(innerId) => s"${idToStr(innerId)}_OuterSize" + case InnerSizeIdentifier(innerId) => s"${idToStr(innerId)}_InnerSize" + } + + def publicMemberName(id: Identifier) = idToStr(id) def types2class(names: List[String]) = names.map(x => type2class(x)).mkString(".") override def kstreamName: String = "KaitaiStream" override def kstructName: String = "KaitaiStruct" + override def ksErrorName(err: KSError): String = s"KaitaiStream.${err.name}" } diff --git a/shared/src/main/scala/io/kaitai/struct/languages/JavaScriptCompiler.scala b/shared/src/main/scala/io/kaitai/struct/languages/JavaScriptCompiler.scala index 17d6db711..31b07acfe 100644 --- a/shared/src/main/scala/io/kaitai/struct/languages/JavaScriptCompiler.scala +++ b/shared/src/main/scala/io/kaitai/struct/languages/JavaScriptCompiler.scala @@ -17,8 +17,7 @@ class JavaScriptCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig) with UniversalDoc with AllocateIOLocalVar with EveryReadIsExpression - with SwitchIfOps - with FixedContentsUsingArrayByteLiteral { + with SwitchIfOps { import JavaScriptCompiler._ override val translator = new JavaScriptTranslator(typeProvider) @@ -184,11 +183,6 @@ class JavaScriptCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig) out.puts("}") } - override def attrFixedContentsParse(attrName: Identifier, contents: String): Unit = { - out.puts(s"${privateMemberName(attrName)} = " + - s"$normalIO.ensureFixedContents($contents);") - } - override def attrProcess(proc: ProcessExpr, varSrc: Identifier, varDest: Identifier, rep: RepeatSpec): Unit = { val srcExpr = getRawIdExpr(varSrc, rep) @@ -290,7 +284,7 @@ class JavaScriptCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig) } // TODO: replace this with UniversalFooter - override def condIfFooter(expr: expr): Unit = { + override def condIfFooter: Unit = { out.dec out.puts("}") } @@ -557,16 +551,20 @@ class JavaScriptCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig) override def ksErrorName(err: KSError): String = JavaScriptCompiler.ksErrorName(err) override def attrValidateExpr( - attrId: Identifier, - attrType: DataType, + attr: AttrLikeSpec, checkExpr: Ast.expr, err: KSError, - errArgs: List[Ast.expr] + useIo: Boolean, + expected: Option[Ast.expr] = None ): Unit = { - val errArgsStr = errArgs.map(translator.translate).mkString(", ") + val errArgsStr = expected.map(expression) ++ List( + expression(Ast.expr.InternalName(attr.id)), + if (useIo) expression(Ast.expr.InternalName(IoIdentifier)) else "null", + expression(Ast.expr.Str(attr.path.mkString("/", "/", ""))) + ) out.puts(s"if (!(${translator.translate(checkExpr)})) {") out.inc - out.puts(s"throw new ${ksErrorName(err)}($errArgsStr);") + out.puts(s"throw new ${ksErrorName(err)}(${errArgsStr.mkString(", ")});") out.dec out.puts("}") } diff --git a/shared/src/main/scala/io/kaitai/struct/languages/LuaCompiler.scala b/shared/src/main/scala/io/kaitai/struct/languages/LuaCompiler.scala index 45e9464c8..bf8065a18 100644 --- a/shared/src/main/scala/io/kaitai/struct/languages/LuaCompiler.scala +++ b/shared/src/main/scala/io/kaitai/struct/languages/LuaCompiler.scala @@ -12,7 +12,6 @@ class LuaCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig) extends LanguageCompiler(typeProvider, config) with AllocateIOLocalVar with EveryReadIsExpression - with FixedContentsUsingArrayByteLiteral with ObjectOrientedLanguage with SingleOutputFile with UniversalDoc @@ -150,14 +149,11 @@ class LuaCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig) out.puts("end") } - override def attrFixedContentsParse(attrName: Identifier, contents: String): Unit = - out.puts(s"${privateMemberName(attrName)} = self._io:ensure_fixed_contents($contents)") - override def condIfHeader(expr: Ast.expr): Unit = { out.puts(s"if ${expression(expr)} then") out.inc } - override def condIfFooter(expr: Ast.expr): Unit = { + override def condIfFooter: Unit = { out.dec out.puts("end") } @@ -402,24 +398,21 @@ class LuaCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig) override def ksErrorName(err: KSError): String = LuaCompiler.ksErrorName(err) override def attrValidateExpr( - attrId: Identifier, - attrType: DataType, + attr: AttrLikeSpec, checkExpr: Ast.expr, err: KSError, - errArgs: List[Ast.expr] + useIo: Boolean, + expected: Option[Ast.expr] = None ): Unit = { - val errArgsCode = errArgs.map(translator.translate) + val actualStr = expression(Ast.expr.InternalName(attr.id)) out.puts(s"if not(${translator.translate(checkExpr)}) then") out.inc val msg = err match { case _: ValidationNotEqualError => { - val (expected, actual) = ( - errArgsCode.lift(0).getOrElse("[expected]"), - errArgsCode.lift(1).getOrElse("[actual]") - ) - s""""not equal, expected " .. $expected .. ", but got " .. $actual""" + val expectedStr = expected.get + s""""not equal, expected " .. $expectedStr .. ", but got " .. $actualStr""" } - case _ => "\"" + ksErrorName(err) + "\"" + case _ => expression(Ast.expr.Str(ksErrorName(err))) } out.puts(s"error($msg)") out.dec diff --git a/shared/src/main/scala/io/kaitai/struct/languages/NimCompiler.scala b/shared/src/main/scala/io/kaitai/struct/languages/NimCompiler.scala index 565ef2700..87bc33b7d 100644 --- a/shared/src/main/scala/io/kaitai/struct/languages/NimCompiler.scala +++ b/shared/src/main/scala/io/kaitai/struct/languages/NimCompiler.scala @@ -13,7 +13,6 @@ class NimCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig) with SingleOutputFile with EveryReadIsExpression with UpperCamelCaseClasses - with FixedContentsUsingArrayByteLiteral with UniversalFooter with AllocateIOLocalVar with SwitchIfOps @@ -86,9 +85,6 @@ class NimCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig) importList.add(file) } override def alignToByte(io: String): Unit = out.puts(s"alignToByte($io)") - override def attrFixedContentsParse(attrName: Identifier, contents: String): Unit = { - out.puts(s"this.${idToStr(attrName)} = $normalIO.ensureFixedContents($contents)") - } // def attrParse(attr: AttrLikeSpec, id: Identifier, defEndian: Option[Endianness]): Unit = ??? override def attrParseHybrid(leProc: () => Unit, beProc: () => Unit): Unit = { out.puts("if this.isLe:") diff --git a/shared/src/main/scala/io/kaitai/struct/languages/PHPCompiler.scala b/shared/src/main/scala/io/kaitai/struct/languages/PHPCompiler.scala index 672fcab3b..6a50fc8ce 100644 --- a/shared/src/main/scala/io/kaitai/struct/languages/PHPCompiler.scala +++ b/shared/src/main/scala/io/kaitai/struct/languages/PHPCompiler.scala @@ -16,7 +16,6 @@ class PHPCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig) with AllocateIOLocalVar with UniversalFooter with UniversalDoc - with FixedContentsUsingArrayByteLiteral with EveryReadIsExpression { import PHPCompiler._ @@ -191,9 +190,6 @@ class PHPCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig) out.puts("}") } - override def attrFixedContentsParse(attrName: Identifier, contents: String): Unit = - out.puts(s"${privateMemberName(attrName)} = $normalIO->ensureFixedContents($contents);") - override def attrProcess(proc: ProcessExpr, varSrc: Identifier, varDest: Identifier, rep: RepeatSpec): Unit = { val srcExpr = getRawIdExpr(varSrc, rep) @@ -479,16 +475,20 @@ class PHPCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig) override def ksErrorName(err: KSError): String = PHPCompiler.ksErrorName(err) override def attrValidateExpr( - attrId: Identifier, - attrType: DataType, + attr: AttrLikeSpec, checkExpr: Ast.expr, err: KSError, - errArgs: List[Ast.expr] + useIo: Boolean, + expected: Option[Ast.expr] = None ): Unit = { - val errArgsStr = errArgs.map(translator.translate).mkString(", ") + val errArgsStr = expected.map(expression) ++ List( + expression(Ast.expr.InternalName(attr.id)), + if (useIo) expression(Ast.expr.InternalName(IoIdentifier)) else "null", + expression(Ast.expr.Str(attr.path.mkString("/", "/", ""))) + ) out.puts(s"if (!(${translator.translate(checkExpr)})) {") out.inc - out.puts(s"throw new ${ksErrorName(err)}($errArgsStr);") + out.puts(s"throw new ${ksErrorName(err)}(${errArgsStr.mkString(", ")});") out.dec out.puts("}") } diff --git a/shared/src/main/scala/io/kaitai/struct/languages/PerlCompiler.scala b/shared/src/main/scala/io/kaitai/struct/languages/PerlCompiler.scala index 732da5622..5b0b8badb 100644 --- a/shared/src/main/scala/io/kaitai/struct/languages/PerlCompiler.scala +++ b/shared/src/main/scala/io/kaitai/struct/languages/PerlCompiler.scala @@ -15,7 +15,6 @@ class PerlCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig) with UniversalFooter with UpperCamelCaseClasses with AllocateIOLocalVar - with FixedContentsUsingArrayByteLiteral with SwitchIfOps with EveryReadIsExpression { @@ -163,10 +162,6 @@ class PerlCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig) out.puts("}") } - override def attrFixedContentsParse(attrName: Identifier, contents: String): Unit = { - out.puts(s"${privateMemberName(attrName)} = $normalIO->ensure_fixed_contents($contents);") - } - override def attrProcess(proc: ProcessExpr, varSrc: Identifier, varDest: Identifier, rep: RepeatSpec): Unit = { val srcExpr = getRawIdExpr(varSrc, rep) diff --git a/shared/src/main/scala/io/kaitai/struct/languages/PythonCompiler.scala b/shared/src/main/scala/io/kaitai/struct/languages/PythonCompiler.scala index abfda41bb..b58912f17 100644 --- a/shared/src/main/scala/io/kaitai/struct/languages/PythonCompiler.scala +++ b/shared/src/main/scala/io/kaitai/struct/languages/PythonCompiler.scala @@ -16,8 +16,10 @@ class PythonCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig) with SingleOutputFile with UniversalFooter with EveryReadIsExpression + with FetchInstances + with EveryWriteIsExpression + with GenericChecks with AllocateIOLocalVar - with FixedContentsUsingArrayByteLiteral with UniversalDoc with SwitchIfOps with NoNeedForFullClassPath { @@ -28,6 +30,10 @@ class PythonCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig) override def innerDocstrings = true + /** See [[subIOWriteBackHeader]] => the code generated when `true` will be inside the definition + * of the "write back handler" callback function. */ + private var inSubIOWriteBackHandler = false + override def universalFooter: Unit = { out.dec out.puts @@ -49,7 +55,7 @@ class PythonCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig) outHeader.puts importList.add("import kaitaistruct") - importList.add(s"from kaitaistruct import $kstructName, $kstreamName, BytesIO") + importList.add(s"from kaitaistruct import $kstructNameFull, $kstreamName, BytesIO") out.puts out.puts @@ -89,7 +95,7 @@ class PythonCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig) } override def classHeader(name: String): Unit = { - out.puts(s"class ${type2class(name)}($kstructName):") + out.puts(s"class ${type2class(name)}($kstructNameFull):") out.inc } @@ -97,11 +103,16 @@ class PythonCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig) val endianAdd = if (isHybrid) ", _is_le=None" else "" val paramsList = Utils.join(params.map((p) => paramName(p.id)), ", ", ", ", "") - out.puts(s"def __init__(self$paramsList, _io, _parent=None, _root=None$endianAdd):") + val ioDefaultVal = if (config.readWrite) "=None" else "" + out.puts(s"def __init__(self$paramsList, _io$ioDefaultVal, _parent=None, _root=None$endianAdd):") out.inc out.puts("self._io = _io") out.puts("self._parent = _parent") - out.puts("self._root = _root if _root else self") + if (name == rootClassName) { + out.puts("self._root = _root if _root else self") + } else { + out.puts("self._root = _root") + } if (isHybrid) out.puts("self._is_le = _is_le") @@ -120,11 +131,11 @@ class PythonCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig) } override def runReadCalc(): Unit = { - out.puts(s"if not hasattr(self, '_is_le'):") + out.puts("if not hasattr(self, '_is_le'):") out.inc out.puts(s"raise ${ksErrorName(UndecidedEndiannessError)}(" + "\"" + typeProvider.nowClass.path.mkString("/", "/", "") + "\")") out.dec - out.puts(s"elif self._is_le == True:") + out.puts("elif self._is_le == True:") out.inc out.puts("self._read_le()") out.dec @@ -134,6 +145,21 @@ class PythonCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig) out.dec } + override def runWriteCalc(): Unit = { + out.puts("if not hasattr(self, '_is_le'):") + out.inc + out.puts(s"raise ${ksErrorName(UndecidedEndiannessError)}(" + "\"" + typeProvider.nowClass.path.mkString("/", "/", "") + "\")") + out.dec + out.puts("elif self._is_le == True:") + out.inc + out.puts("self._write__seq_le()") + out.dec + out.puts("elif self._is_le == False:") + out.inc + out.puts("self._write__seq_be()") + out.dec + } + override def readHeader(endian: Option[FixedEndian], isEmpty: Boolean): Unit = { val suffix = endian match { case Some(e) => s"_${e.toSuffix}" @@ -145,12 +171,82 @@ class PythonCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig) out.puts("pass") } - override def readFooter() = universalFooter + override def fetchInstancesHeader(): Unit = { + out.puts + out.puts("def _fetch_instances(self):") + out.inc + out.puts("pass") + } + + override def fetchInstancesFooter: Unit = universalFooter + + override def attrInvokeFetchInstances(baseExpr: Ast.expr, exprType: DataType, dataType: DataType): Unit = { + val expr = expression(baseExpr) + out.puts(s"$expr._fetch_instances()") + } + + override def attrInvokeInstance(instName: InstanceIdentifier): Unit = { + out.puts(s"_ = self.${publicMemberName(instName)}") + } + + override def writeHeader(endian: Option[FixedEndian], isEmpty: Boolean): Unit = { + out.puts + endian match { + case Some(e) => + out.puts(s"def _write__seq_${e.toSuffix}(self):") + out.inc + if (isEmpty) + out.puts("pass") + case None => + out.puts("def _write__seq(self, io=None):") + out.inc + // FIXME: remove super() args when dropping support for Python 2 (see + // https://pylint.readthedocs.io/en/v2.16.2/user_guide/messages/refactor/super-with-arguments.html) + out.puts(s"super(${types2class(typeProvider.nowClass.name)}, self)._write__seq(io)") + } + } + + override def checkHeader(): Unit = { + out.puts + out.puts("def _check(self):") + out.inc + out.puts("pass") + } + + override def writeInstanceHeader(instName: InstanceIdentifier): Unit = { + out.puts + out.puts(s"def _write_${publicMemberName(instName)}(self):") + out.inc + instanceClearWriteFlag(instName) + } + + override def checkInstanceHeader(instName: InstanceIdentifier): Unit = { + out.puts + out.puts(s"def _check_${publicMemberName(instName)}(self):") + out.inc + out.puts("pass") + } override def attributeDeclaration(attrName: Identifier, attrType: DataType, isNullable: Boolean): Unit = {} override def attributeReader(attrName: Identifier, attrType: DataType, isNullable: Boolean): Unit = {} + override def attributeSetter(attrName: Identifier, attrType: DataType, isNullable: Boolean): Unit = { + if (attrName.isInstanceOf[InstanceIdentifier]) { + val name = publicMemberName(attrName) + + out.puts(s"@$name.setter") + out.puts(s"def $name(self, v):") + out.inc + handleAssignmentSimple(attrName, "v") + out.dec + } + } + + override def attrSetProperty(base: Ast.expr, propName: Identifier, value: String): Unit = { + out.puts(s"${expression(base)}.${publicMemberName(propName)} = $value") + } + override def universalDoc(doc: DocSpec): Unit = { val docStr = doc.summary match { case Some(summary) => @@ -179,9 +275,6 @@ class PythonCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig) out.putsLines("", "\"\"\"" + docStr + refStr + "\"\"\"") } - override def attrFixedContentsParse(attrName: Identifier, contents: String): Unit = - out.puts(s"${privateMemberName(attrName)} = self._io.ensure_fixed_contents($contents)") - override def attrParseHybrid(leProc: () => Unit, beProc: () => Unit): Unit = { out.puts("if self._is_le:") out.inc @@ -224,13 +317,89 @@ class PythonCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig) importList.add(s"import $pkgName") s"$pkgName.${type2class(name.last)}" } - out.puts(s"_process = $procClass(${args.map(expression).mkString(", ")})") s"_process.decode($srcExpr)" } handleAssignment(varDest, expr, rep, false) } + override def attrUnprocess(proc: ProcessExpr, varSrc: Identifier, varDest: Identifier, rep: RepeatSpec, dt: BytesType, exprTypeOpt: Option[DataType]): Unit = { + val srcExpr = varSrc match { + // use `_raw_items[_raw_items.size - 1]` + case _: RawIdentifier => getRawIdExpr(varSrc, rep) + // but `items[_index]` + case _ => expression(itemExpr(varSrc, rep)) + } + + val expr = proc match { + case ProcessXor(xorValue) => + val argStr = if (inSubIOWriteBackHandler) "_process_val" else expression(xorValue) + val procName = translator.detectType(xorValue) match { + case _: IntType => "process_xor_one" + case _: BytesType => "process_xor_many" + } + s"$kstreamName.$procName($srcExpr, $argStr)" + case ProcessZlib => + importList.add("import zlib") + s"zlib.compress($srcExpr)" + case ProcessRotate(isLeft, rotValue) => + val argStr = if (inSubIOWriteBackHandler) "_process_val" else expression(rotValue) + val expr = if (!isLeft) { + argStr + } else { + s"8 - ($argStr)" + } + s"$kstreamName.process_rotate_left($srcExpr, $expr, 1)" + case ProcessCustom(name, args) => + val procClass = if (name.length == 1) { + val onlyName = name.head + val className = type2class(onlyName) + importList.add(s"from $onlyName import $className") + className + } else { + val pkgName = name.init.mkString(".") + importList.add(s"import $pkgName") + s"$pkgName.${type2class(name.last)}" + } + + val procName = if (inSubIOWriteBackHandler) { + "_process_val" + } else { + val procName = s"_process_${idToStr(varSrc)}" + out.puts(s"$procName = $procClass(${args.map(expression).mkString(", ")})") + procName + } + s"$procName.encode($srcExpr)" + } + handleAssignment(varDest, expr, rep, false) + } + + override def attrUnprocessPrepareBeforeSubIOHandler(proc: ProcessExpr, varSrc: Identifier): Unit = { + // NOTE: the local variable "_process_val" will be captured in a default value of a parameter + // when defining the "write back handler" function (in subIOWriteBackHeader), see + // https://docs.python.org/3/faq/programming.html#why-do-lambdas-defined-in-a-loop-with-different-values-all-return-the-same-result + proc match { + case ProcessXor(xorValue) => + out.puts(s"_process_val = ${expression(xorValue)}") + case ProcessRotate(_, rotValue) => + out.puts(s"_process_val = ${expression(rotValue)}") + case ProcessZlib => // no process arguments + case ProcessCustom(name, args) => + val procClass = if (name.length == 1) { + val onlyName = name.head + val className = type2class(onlyName) + importList.add(s"from $onlyName import $className") + className + } else { + val pkgName = name.init.mkString(".") + importList.add(s"import $pkgName") + s"$pkgName.${type2class(name.last)}" + } + val procName = "_process_val" + out.puts(s"$procName = $procClass(${args.map(expression).mkString(", ")})") + } + } + override def normalIO: String = "self._io" override def allocateIO(varName: Identifier, rep: RepeatSpec): String = { @@ -243,11 +412,53 @@ class PythonCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig) ioName } + override def allocateIOFixed(varName: Identifier, size: String): String = { + val varStr = privateMemberName(varName) + val ioName = s"_io_${idToStr(varName)}" + + // NOTE: in Python 2, bytes() converts an integer argument to a string (e.g. bytes(12) => '12'), + // so we have to use bytearray() instead + out.puts(s"$ioName = $kstreamName(BytesIO(bytearray($size)))") + ioName + } + + override def exprIORemainingSize(io: String): String = + s"$io.size() - $io.pos()" + + override def subIOWriteBackHeader(subIO: String, process: Option[ProcessExpr]): String = { + val parentIoName = "parent" + // NOTE: local variables "$subIO" and "_process_val" are captured here as default values of + // "handler" parameters, see + // https://docs.python.org/3/faq/programming.html#why-do-lambdas-defined-in-a-loop-with-different-values-all-return-the-same-result + val processValArg = + process.map(proc => proc match { + case _: ProcessXor | _: ProcessRotate | _: ProcessCustom => + ", _process_val=_process_val" + case _ => + "" + }).getOrElse("") + out.puts(s"def handler(parent, $subIO=$subIO$processValArg):") + out.inc + + inSubIOWriteBackHandler = true + + parentIoName + } + + override def subIOWriteBackFooter(subIO: String): Unit = { + inSubIOWriteBackHandler = false + + out.dec + out.puts(s"$subIO.write_back_handler = $kstreamName.WriteBackHandler(_pos2, handler)") + } + + override def addChildIO(io: String, childIO: String): Unit = + out.puts(s"$io.add_child_stream($childIO)") + def getRawIdExpr(varName: Identifier, rep: RepeatSpec): String = { val memberName = privateMemberName(varName) rep match { case NoRepeat => memberName - case RepeatExpr(_) => s"$memberName[i]" case _ => s"$memberName[-1]" } } @@ -260,14 +471,22 @@ class PythonCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig) override def pushPos(io: String): Unit = out.puts(s"_pos = $io.pos()") + override def pushPosForSubIOWriteBackHandler(io: String): Unit = + out.puts(s"_pos2 = $io.pos()") + override def seek(io: String, pos: Ast.expr): Unit = out.puts(s"$io.seek(${expression(pos)})") + override def seekRelative(io: String, relPos: String): Unit = + out.puts(s"$io.seek($io.pos() + ($relPos))") + override def popPos(io: String): Unit = out.puts(s"$io.seek(_pos)") - override def alignToByte(io: String): Unit = - out.puts(s"$io.align_to_byte()") + // NOTE: the compiler does not need to output align_to_byte() calls for Python anymore, + // since the byte alignment is handled by the runtime library since commit + // https://github.com/kaitai-io/kaitai_struct_python_runtime/commit/1cb84b84d358e1cdffe35845d1e6688bff923952 + override def alignToByte(io: String): Unit = {} override def attrDebugStart(attrId: Identifier, attrType: DataType, ios: Option[String], rep: RepeatSpec): Unit = { ios.foreach { (io) => @@ -299,6 +518,7 @@ class PythonCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig) override def condIfHeader(expr: Ast.expr): Unit = { out.puts(s"if ${expression(expr)}:") out.inc + out.puts("pass") } override def condRepeatInitAttr(id: Identifier, dataType: DataType): Unit = @@ -321,6 +541,16 @@ class PythonCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig) out.puts(s"for i in range(${expression(repeatExpr)}):") out.inc } + + // used for all repetitions in _check() + override def condRepeatCommonHeader(id: Identifier, io: String, dataType: DataType): Unit = { + // TODO: replace range(len()) with enumerate() (see + // https://pylint.readthedocs.io/en/v2.16.2/user_guide/messages/convention/consider-using-enumerate.html) + out.puts(s"for i in range(len(${privateMemberName(id)})):") + out.inc + out.puts("pass") + } + override def handleAssignmentRepeatExpr(id: Identifier, expr: String): Unit = handleAssignmentRepeatEos(id, expr) @@ -388,8 +618,9 @@ class PythonCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig) override def bytesPadTermExpr(expr0: String, padRight: Option[Int], terminator: Option[Int], include: Boolean) = { val expr1 = padRight match { - case Some(padByte) => s"$kstreamName.bytes_strip_right($expr0, $padByte)" - case None => expr0 + case Some(padByte) if terminator.map(term => padByte != term).getOrElse(true) => + s"$kstreamName.bytes_strip_right($expr0, $padByte)" + case _ => expr0 } val expr2 = terminator match { case Some(term) => s"$kstreamName.bytes_terminate($expr1, $term, ${bool2Py(include)})" @@ -415,11 +646,13 @@ class PythonCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig) override def switchIfCaseFirstStart(condition: Ast.expr): Unit = { out.puts(s"if _on == ${expression(condition)}:") out.inc + out.puts("pass") } override def switchIfCaseStart(condition: Ast.expr): Unit = { out.puts(s"elif _on == ${expression(condition)}:") out.inc + out.puts("pass") } override def switchIfCaseEnd(): Unit = @@ -428,10 +661,28 @@ class PythonCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig) override def switchIfElseStart(): Unit = { out.puts(s"else:") out.inc + out.puts("pass") } override def switchIfEnd(): Unit = {} + override def instanceWriteFlagDeclaration(attrName: InstanceIdentifier): Unit = {} + + override def instanceWriteFlagInit(attrName: InstanceIdentifier): Unit = { + instanceClearWriteFlag(attrName) + out.puts(s"self.${publicMemberName(attrName)}__to_write = True") + } + + override def instanceSetWriteFlag(instName: InstanceIdentifier): Unit = { + out.puts(s"self._should_write_${publicMemberName(instName)} = self.${publicMemberName(instName)}__to_write") + } + + override def instanceClearWriteFlag(instName: InstanceIdentifier): Unit = { + out.puts(s"self._should_write_${publicMemberName(instName)} = False") + } + + override def instanceToWriteSetter(instName: InstanceIdentifier): Unit = {} + override def instanceHeader(className: String, instName: InstanceIdentifier, dataType: DataType, isNullable: Boolean): Unit = { out.puts("@property") out.puts(s"def ${publicMemberName(instName)}(self):") @@ -446,11 +697,25 @@ class PythonCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig) out.puts } + override def instanceCheckWriteFlagAndWrite(instName: InstanceIdentifier): Unit = { + out.puts(s"if self._should_write_${publicMemberName(instName)}:") + out.inc + out.puts(s"self._write_${publicMemberName(instName)}()") + out.dec + } + override def instanceReturn(instName: InstanceIdentifier, attrType: DataType): Unit = { // workaround to avoid Python raising an "AttributeError: instance has no attribute" out.puts(s"return getattr(self, '${idToStr(instName)}', None)") } + override def instanceInvalidate(instName: InstanceIdentifier): Unit = { + out.puts(s"def _invalidate_${publicMemberName(instName)}(self):") + out.inc + out.puts(s"del ${privateMemberName(instName)}") + out.dec + } + override def enumDeclaration(curClass: String, enumName: String, enumColl: Seq[(Long, String)]): Unit = { importList.add("from enum import IntEnum") @@ -461,6 +726,9 @@ class PythonCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig) out.dec } + override def internalEnumIntType(basedOn: IntType): DataType = + basedOn + override def debugClassSequence(seq: List[AttrSpec]) = { val seqStr = seq.map((attr) => "\"" + idToStr(attr.id) + "\"").mkString(", ") out.puts(s"SEQ_FIELDS = [$seqStr]") @@ -474,6 +742,107 @@ class PythonCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig) out.dec } + override def attrPrimitiveWrite( + io: String, + valueExpr: Ast.expr, + dataType: DataType, + defEndian: Option[FixedEndian], + exprTypeOpt: Option[DataType] + ): Unit = { + val expr = expression(valueExpr) + + val stmt = dataType match { + case t: ReadableType => + s"$io.write_${t.apiCall(defEndian)}($expr)" + case BitsType1(bitEndian) => + s"$io.write_bits_int_${bitEndian.toSuffix}(1, ${translator.boolToInt(valueExpr)})" + case BitsType(width: Int, bitEndian) => + s"$io.write_bits_int_${bitEndian.toSuffix}($width, $expr)" + case _: BytesType => + s"$io.write_bytes($expr)" + } + out.puts(stmt) + } + + override def attrBytesLimitWrite(io: String, expr: Ast.expr, size: String, term: Int, padRight: Int): Unit = + out.puts(s"$io.write_bytes_limit(${expression(expr)}, $size, $term, $padRight)") + + override def attrUserTypeInstreamWrite(io: String, valueExpr: Ast.expr, dataType: DataType, exprType: DataType) = { + val expr = expression(valueExpr) + out.puts(s"$expr._write__seq($io)") + } + + override def exprStreamToByteArray(io: String): String = + s"$io.to_byte_array()" + + override def attrBasicCheck(checkExpr: Ast.expr, actual: Ast.expr, expected: Ast.expr, msg: String): Unit = { + val msgStr = expression(Ast.expr.Str(msg)) + + out.puts(s"if ${expression(checkExpr)}:") + out.inc + out.puts(s"raise kaitaistruct.ConsistencyError($msgStr, ${expression(actual)}, ${expression(expected)})") + out.dec + } + + override def attrObjectsEqualCheck(actual: Ast.expr, expected: Ast.expr, msg: String): Unit = { + val msgStr = expression(Ast.expr.Str(msg)) + + out.puts(s"if ${expression(actual)} != ${expression(expected)}:") + out.inc + out.puts(s"raise kaitaistruct.ConsistencyError($msgStr, ${expression(actual)}, ${expression(expected)})") + out.dec + } + + override def attrParentParamCheck(actualParentExpr: Ast.expr, ut: UserType, shouldDependOnIo: Option[Boolean], msg: String): Unit = { + if (ut.isOpaque) + return + /** @note Must be kept in sync with [[PythonCompiler.parseExpr]] */ + val (expectedParent, dependsOnIo) = ut.forcedParent match { + case Some(USER_TYPE_NO_PARENT) => ("None", false) + case Some(fp) => + (expression(fp), userExprDependsOnIo(fp)) + case None => ("self", false) + } + if (shouldDependOnIo.map(shouldDepend => dependsOnIo != shouldDepend).getOrElse(false)) + return + + val msgStr = expression(Ast.expr.Str(msg)) + + out.puts(s"if ${expression(actualParentExpr)} != $expectedParent:") + out.inc + out.puts(s"raise kaitaistruct.ConsistencyError($msgStr, ${expression(actualParentExpr)}, $expectedParent)") + out.dec + } + + override def attrIsEofCheck(io: String, expectedIsEof: Boolean, msg: String): Unit = { + val msgStr = expression(Ast.expr.Str(msg)) + + val eofExpr = s"$io.is_eof()" + val ifExpr = if (expectedIsEof) { + s"not $eofExpr" + } else { + eofExpr + } + out.puts(s"if $ifExpr:") + out.inc + out.puts(s"raise kaitaistruct.ConsistencyError($msgStr, ${exprIORemainingSize(io)}, 0)") + out.dec + } + + override def condIfIsEofHeader(io: String, wantedIsEof: Boolean): Unit = { + val eofExpr = s"$io.is_eof()" + val ifExpr = if (!wantedIsEof) { + s"not $eofExpr" + } else { + eofExpr + } + + out.puts(s"if $ifExpr:") + out.inc + } + + override def condIfIsEofFooter: Unit = universalFooter + def bool2Py(b: Boolean): String = if (b) { "True" } else { "False" } override def idToStr(id: Identifier): String = PythonCompiler.idToStr(id) @@ -487,19 +856,30 @@ class PythonCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig) override def ksErrorName(err: KSError): String = PythonCompiler.ksErrorName(err) override def attrValidateExpr( - attrId: Identifier, - attrType: DataType, + attr: AttrLikeSpec, checkExpr: Ast.expr, err: KSError, - errArgs: List[Ast.expr] + useIo: Boolean, + expected: Option[Ast.expr] = None ): Unit = { - val errArgsStr = errArgs.map(translator.translate).mkString(", ") + val errArgsStr = expected.map(expression) ++ List( + expression(Ast.expr.InternalName(attr.id)), + if (useIo) expression(Ast.expr.InternalName(IoIdentifier)) else "None", + expression(Ast.expr.Str(attr.path.mkString("/", "/", ""))) + ) out.puts(s"if not ${translator.translate(checkExpr)}:") out.inc - out.puts(s"raise ${ksErrorName(err)}($errArgsStr)") + out.puts(s"raise ${ksErrorName(err)}(${errArgsStr.mkString(", ")})") out.dec } + def kstructNameFull: String = { + ((config.autoRead, config.readWrite) match { + case (_, true) => "ReadWrite" + case (_, false) => "" + }) + kstructName + } + def userType2class(t: UserType): String = { val name = t.classSpec.get.name val firstName = name.head @@ -528,6 +908,8 @@ object PythonCompiler extends LanguageCompilerStatic case NumberedIdentifier(idx) => s"_${NumberedIdentifier.TEMPLATE}$idx" case InstanceIdentifier(name) => s"_m_$name" case RawIdentifier(innerId) => s"_raw_${idToStr(innerId)}" + case OuterSizeIdentifier(innerId) => s"${idToStr(innerId)}__outer_size" + case InnerSizeIdentifier(innerId) => s"${idToStr(innerId)}__inner_size" } def publicMemberName(id: Identifier): String = diff --git a/shared/src/main/scala/io/kaitai/struct/languages/RubyCompiler.scala b/shared/src/main/scala/io/kaitai/struct/languages/RubyCompiler.scala index 6ce5d98e9..b38f572c4 100644 --- a/shared/src/main/scala/io/kaitai/struct/languages/RubyCompiler.scala +++ b/shared/src/main/scala/io/kaitai/struct/languages/RubyCompiler.scala @@ -18,7 +18,6 @@ class RubyCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig) with UpperCamelCaseClasses with AllocateIOLocalVar with EveryReadIsExpression - with FixedContentsUsingArrayByteLiteral with NoNeedForFullClassPath { import RubyCompiler._ @@ -174,9 +173,6 @@ class RubyCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig) out.puts("end") } - override def attrFixedContentsParse(attrName: Identifier, contents: String): Unit = - out.puts(s"${privateMemberName(attrName)} = $normalIO.ensure_fixed_contents($contents)") - override def attrProcess(proc: ProcessExpr, varSrc: Identifier, varDest: Identifier, rep: RepeatSpec): Unit = { val srcExpr = getRawIdExpr(varSrc, rep) @@ -460,14 +456,18 @@ class RubyCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig) override def ksErrorName(err: KSError): String = RubyCompiler.ksErrorName(err) override def attrValidateExpr( - attrId: Identifier, - attrType: DataType, + attr: AttrLikeSpec, checkExpr: Ast.expr, err: KSError, - errArgs: List[Ast.expr] + useIo: Boolean, + expected: Option[Ast.expr] = None ): Unit = { - val errArgsStr = errArgs.map(translator.translate).mkString(", ") - out.puts(s"raise ${ksErrorName(err)}.new($errArgsStr) if not ${translator.translate(checkExpr)}") + val errArgsStr = expected.map(expression) ++ List( + expression(Ast.expr.InternalName(attr.id)), + if (useIo) expression(Ast.expr.InternalName(IoIdentifier)) else "nil", + expression(Ast.expr.Str(attr.path.mkString("/", "/", ""))) + ) + out.puts(s"raise ${ksErrorName(err)}.new(${errArgsStr.mkString(", ")}) if not ${translator.translate(checkExpr)}") } def types2class(names: List[String]) = names.map(type2class).mkString("::") diff --git a/shared/src/main/scala/io/kaitai/struct/languages/RustCompiler.scala b/shared/src/main/scala/io/kaitai/struct/languages/RustCompiler.scala index 0e55cb402..8afcc65cb 100644 --- a/shared/src/main/scala/io/kaitai/struct/languages/RustCompiler.scala +++ b/shared/src/main/scala/io/kaitai/struct/languages/RustCompiler.scala @@ -16,7 +16,6 @@ class RustCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig) with AllocateIOLocalVar with UniversalFooter with UniversalDoc - with FixedContentsUsingArrayByteLiteral with EveryReadIsExpression { import RustCompiler._ @@ -170,9 +169,6 @@ class RustCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig) out.puts("}") } - override def attrFixedContentsParse(attrName: Identifier, contents: String): Unit = - out.puts(s"${privateMemberName(attrName)} = $normalIO.ensureFixedContents($contents);") - override def attrProcess(proc: ProcessExpr, varSrc: Identifier, varDest: Identifier, rep: RepeatSpec): Unit = { val srcExpr = getRawIdExpr(varSrc, rep) diff --git a/shared/src/main/scala/io/kaitai/struct/languages/components/AllocateIOLocalVar.scala b/shared/src/main/scala/io/kaitai/struct/languages/components/AllocateIOLocalVar.scala index fd84dce12..32cb79047 100644 --- a/shared/src/main/scala/io/kaitai/struct/languages/components/AllocateIOLocalVar.scala +++ b/shared/src/main/scala/io/kaitai/struct/languages/components/AllocateIOLocalVar.scala @@ -9,5 +9,21 @@ import io.kaitai.struct.format.{AttrSpec, Identifier, RepeatSpec} trait AllocateIOLocalVar extends ExtraAttrs { def allocateIO(varName: Identifier, rep: RepeatSpec): String + /** + * Allocates a fixed-size KaitaiStream IO object for writing into it. + * @param varName variable name to use to generate IO name + * @param size size expression to use + * @return name of generated IO local variable as string + */ + def allocateIOFixed(varName: Identifier, size: String): String = ??? + + /** + * Allocates a growing KaitaiStream IO object for writing into it. + * This one is expected to grow indefinitely as one writes more data + * into it, never raising a buffer overflow exception + * @param varName variable name to use to generate IO name + * @return name of generated IO local variable as string + */ + def allocateIOGrowing(varName: Identifier): String = ??? override def extraAttrForIO(id: Identifier, rep: RepeatSpec): List[AttrSpec] = List() } diff --git a/shared/src/main/scala/io/kaitai/struct/languages/components/CommonReads.scala b/shared/src/main/scala/io/kaitai/struct/languages/components/CommonReads.scala index 0640b3af5..41ae7adf6 100644 --- a/shared/src/main/scala/io/kaitai/struct/languages/components/CommonReads.scala +++ b/shared/src/main/scala/io/kaitai/struct/languages/components/CommonReads.scala @@ -101,5 +101,5 @@ trait CommonReads extends LanguageCompiler { * @param attr attribute to run validations for */ def attrValidateAll(attr: AttrLikeSpec) = - attr.valid.foreach(valid => attrValidate(attr.id, attr, valid)) + attr.valid.foreach(valid => attrValidate(attr.id, attr, valid, true)) } diff --git a/shared/src/main/scala/io/kaitai/struct/languages/components/EveryReadIsExpression.scala b/shared/src/main/scala/io/kaitai/struct/languages/components/EveryReadIsExpression.scala index 1bc427c6f..97dd86ab2 100644 --- a/shared/src/main/scala/io/kaitai/struct/languages/components/EveryReadIsExpression.scala +++ b/shared/src/main/scala/io/kaitai/struct/languages/components/EveryReadIsExpression.scala @@ -186,6 +186,36 @@ trait EveryReadIsExpression attrParse2(rawId, byteType, io, rep, true, defEndian) + if (config.readWrite) { + if (writeNeedsOuterSize(byteType)) { + /** @note Must be kept in sync with [[attrBytesTypeParse]] */ + val rawRawId = byteType.process match { + case None => rawId + case Some(_) => RawIdentifier(rawId) + } + val item = itemExpr(rawRawId, rep) + val itemSizeExprStr = expression(Ast.expr.Attribute(item, Ast.identifier("size"))) + /** FIXME: cannot use [[handleAssignment]] because [[handleAssignmentRepeatUntil]] + * always tries to assign the value to the [[Identifier.ITERATOR]] variable */ + if (rep == NoRepeat) { + handleAssignmentSimple(OuterSizeIdentifier(id), itemSizeExprStr) + } else { + handleAssignmentRepeatEos(OuterSizeIdentifier(id), itemSizeExprStr) + } + } + if (writeNeedsInnerSize(byteType)) { + val item = itemExpr(rawId, rep) + val itemSizeExprStr = expression(Ast.expr.Attribute(item, Ast.identifier("size"))) + /** FIXME: cannot use [[handleAssignment]] because [[handleAssignmentRepeatUntil]] + * always tries to assign the value to the [[Identifier.ITERATOR]] variable */ + if (rep == NoRepeat) { + handleAssignmentSimple(InnerSizeIdentifier(id), itemSizeExprStr) + } else { + handleAssignmentRepeatEos(InnerSizeIdentifier(id), itemSizeExprStr) + } + } + } + val extraType = rep match { case NoRepeat => byteType case _ => ArrayTypeInStream(byteType) @@ -265,4 +295,17 @@ trait EveryReadIsExpression case _ => super.attrDebugNeeded(attrId) } } + + def itemExpr(id: Identifier, rep: RepeatSpec): Ast.expr = { + val astId = Ast.expr.InternalName(id) + rep match { + case NoRepeat => + astId + case _ => + Ast.expr.Subscript( + astId, + Ast.expr.Name(Ast.identifier(Identifier.INDEX)) + ) + } + } } diff --git a/shared/src/main/scala/io/kaitai/struct/languages/components/EveryWriteIsExpression.scala b/shared/src/main/scala/io/kaitai/struct/languages/components/EveryWriteIsExpression.scala new file mode 100644 index 000000000..dec964d13 --- /dev/null +++ b/shared/src/main/scala/io/kaitai/struct/languages/components/EveryWriteIsExpression.scala @@ -0,0 +1,449 @@ +package io.kaitai.struct.languages.components + +import io.kaitai.struct.datatype.DataType +import io.kaitai.struct.datatype.DataType._ +import io.kaitai.struct.datatype.FixedEndian +import io.kaitai.struct.exprlang.Ast +import io.kaitai.struct.format._ + +import scala.collection.mutable.ListBuffer +import io.kaitai.struct.datatype._ + +trait EveryWriteIsExpression + extends LanguageCompiler + with ObjectOrientedLanguage + with EveryReadIsExpression + with GenericChecks { + override def attrWrite(attr: AttrLikeSpec, id: Identifier, defEndian: Option[Endianness]): Unit = { + val checksShouldDependOnIo: Option[Boolean] = + if (userExprDependsOnIo(attr.cond.ifExpr)) { + None + } else { + Some(true) + } + + attrParseIfHeader(id, attr.cond.ifExpr) + + // Manage IO & seeking for ParseInstances + val io = attr match { + case pis: ParseInstanceSpec => + val io = pis.io match { + case None => normalIO + case Some(ex) => useIO(ex) + } + pis.pos.foreach { pos => + pushPos(io) + seek(io, pos) + } + io + case _ => + // no seeking required for sequence attributes + normalIO + } + + defEndian match { + case Some(_: CalcEndian) | Some(InheritedEndian) => + // FIXME: rename to indicate that it can be used for both parsing/writing + attrParseHybrid( + () => attrWrite0(id, attr, io, Some(LittleEndian), checksShouldDependOnIo), + () => attrWrite0(id, attr, io, Some(BigEndian), checksShouldDependOnIo) + ) + case None => + attrWrite0(id, attr, io, None, checksShouldDependOnIo) + case Some(fe: FixedEndian) => + attrWrite0(id, attr, io, Some(fe), checksShouldDependOnIo) + } + + attr match { + case pis: ParseInstanceSpec => + // Restore position, if applicable + if (pis.pos.isDefined) + popPos(io) + case _ => // no seeking required for sequence attributes + } + + attrParseIfFooter(attr.cond.ifExpr) + } + + def attrWrite0( + id: Identifier, + attr: AttrLikeSpec, + io: String, + defEndian: Option[FixedEndian], + checksShouldDependOnIo: Option[Boolean] + ): Unit = { + if (attr.cond.repeat != NoRepeat) + ExtraAttrs.forAttr(attr, this) + .filter(a => a.id.isInstanceOf[RawIdentifier]) + .foreach(a => condRepeatInitAttr(a.id, a.dataType)) + attr.cond.repeat match { + case RepeatEos => + case RepeatExpr(repeatExpr: Ast.expr) => + attrRepeatExprCheck(id, repeatExpr, checksShouldDependOnIo) + case RepeatUntil(untilExpr: Ast.expr) => + if (checksShouldDependOnIo.map(shouldDepend => shouldDepend == false).getOrElse(true)) + attrAssertUntilNotEmpty(id) + case NoRepeat => + } + if (attr.cond.repeat != NoRepeat) { + condRepeatCommonHeader(id, io, attr.dataType) + } + attr.cond.repeat match { + case RepeatEos => + attrIsEofCheck(id, false, io) + case _ => + } + attrWrite2(id, attr.dataType, io, attr.cond.repeat, false, defEndian, checksShouldDependOnIo) + attr.cond.repeat match { + case repUntil: RepeatUntil => + attrAssertUntilCond(id, attr.dataType, repUntil, checksShouldDependOnIo) + case _ => + } + if (attr.cond.repeat != NoRepeat) { + condRepeatCommonFooter + } + attr.cond.repeat match { + case RepeatEos => + attrIsEofCheck(id, true, io) + case _ => + } + } + + def attrWrite2( + id: Identifier, + dataType: DataType, + io: String, + rep: RepeatSpec, + isRaw: Boolean, + defEndian: Option[FixedEndian], + checksShouldDependOnIo: Option[Boolean], + exprTypeOpt: Option[DataType] = None + ): Unit = { + dataType match { + case t: UserType => + attrUserTypeWrite(id, t, io, rep, isRaw, defEndian, checksShouldDependOnIo, exprTypeOpt) + case t: BytesType => + attrBytesTypeWrite(id, t, io, rep, isRaw, checksShouldDependOnIo, exprTypeOpt) + case st: SwitchType => + val isNullable = if (switchBytesOnlyAsRaw) { + st.isNullableSwitchRaw + } else { + st.isNullable + } + + attrSwitchTypeWrite(id, st.on, st.cases, io, rep, defEndian, checksShouldDependOnIo, st.combinedType) + case t: StrFromBytesType => + attrStrTypeWrite(id, t, io, rep, isRaw, checksShouldDependOnIo, exprTypeOpt) + case t: EnumType => + val expr = itemExpr(id, rep) + val exprType = internalEnumIntType(t.basedOn) + attrPrimitiveWrite(io, Ast.expr.Attribute(expr, Ast.identifier("to_i")), t.basedOn, defEndian, Some(exprType)) + case _ => + val expr = itemExpr(id, rep) + attrPrimitiveWrite(io, expr, dataType, defEndian, exprTypeOpt) + } + } + + def attrBytesTypeWrite( + id: Identifier, + t: BytesType, + io: String, + rep: RepeatSpec, + isRaw: Boolean, + checksShouldDependOnIo: Option[Boolean], + exprTypeOpt: Option[DataType] + ): Unit = { + val idToWrite = t.process match { + case Some(proc) => + val rawId = RawIdentifier(id) + attrUnprocess(proc, id, rawId, rep, t, exprTypeOpt) + rawId + case None => + id + } + val item = if (idToWrite.isInstanceOf[RawIdentifier] && rep != NoRepeat) { + // NOTE: This special handling isn't normally needed and one can just use + // `itemExpr(idToWrite, rep)` as usual. The `itemExpr` method assumes that the + // expression it's supposed to generate will be used in a loop where the iteration + // variable `Identifier.INDEX` is available (usually called just `i`) and uses it. This + // is a good default, but it doesn't work if the expression is used between + // `subIOWriteBackHeader` and `subIOWriteBackFooter` (see `attrUserTypeWrite` below), + // because in Java the loop control variable `i` is not "final" or "effectively final". + // + // The workaround is to change the expression so that it doesn't depend on the `i` + // variable. We can do that here, because the `RawIdentifier(...)` array starts empty + // before the loop and each element is added by `attrUnprocess` in each loop iteration - + // so the current item is just the last entry in the `RawIdentifier(...)` array. + // + // See test ProcessRepeatUsertype that requires this. + val astId = Ast.expr.InternalName(idToWrite) + Ast.expr.Subscript( + astId, + Ast.expr.BinOp( + Ast.expr.Attribute( + astId, + Ast.identifier("size") + ), + Ast.operator.Sub, + Ast.expr.IntNum(1) + ) + ) + } else { + itemExpr(idToWrite, rep) + } + val itemBytes = + if (exprTypeOpt.map(exprType => !exprType.isInstanceOf[BytesType]).getOrElse(false)) + Ast.expr.CastToType(item, Ast.typeId(false, Seq("bytes"))) + else + item + attrBytesTypeWrite2(id, io, itemBytes, t, checksShouldDependOnIo, exprTypeOpt) + } + + def attrStrTypeWrite( + id: Identifier, + t: StrFromBytesType, + io: String, + rep: RepeatSpec, + isRaw: Boolean, + checksShouldDependOnIo: Option[Boolean], + exprTypeOpt: Option[DataType] + ): Unit = { + val item = itemExpr(id, rep) + val itemStr = + if (exprTypeOpt.map(exprType => !exprType.isInstanceOf[StrType]).getOrElse(false)) + Ast.expr.CastToType(item, Ast.typeId(false, Seq("str"))) + else + item + val bytes = exprStrToBytes(itemStr, t.encoding) + attrBytesTypeWrite2(id, io, bytes, t.bytes, checksShouldDependOnIo, exprTypeOpt) + } + + def attrBytesTypeWrite2( + id: Identifier, + io: String, + expr: Ast.expr, + t: BytesType, + checksShouldDependOnIo: Option[Boolean], + exprTypeOpt: Option[DataType] + ): Unit = { + attrBytesCheck(id, expr, t, checksShouldDependOnIo) + t match { + case bt: BytesEosType => + attrBytesLimitWrite2(io, expr, bt, exprIORemainingSize(io), bt.padRight, bt.terminator, bt.include, exprTypeOpt) + attrIsEofCheck(id, true, io) + case bt: BytesLimitType => + attrBytesLimitWrite2(io, expr, bt, expression(bt.size), bt.padRight, bt.terminator, bt.include, exprTypeOpt) + case t: BytesTerminatedType => + attrPrimitiveWrite(io, expr, t, None, exprTypeOpt) + if (t.include) { + val actualIndexOfTerm = exprByteArrayIndexOf(expr, t.terminator) + if (!t.eosError) { + condIfHeader(Ast.expr.Compare(actualIndexOfTerm, Ast.cmpop.Eq, Ast.expr.IntNum(-1))) + attrIsEofCheck(id, true, io) + condIfFooter + } + } else { + if (!t.eosError) + condIfIsEofHeader(io, false) + + if (!t.consume) { + if (t.eosError) { + blockScopeHeader + } + pushPos(io) + } + attrPrimitiveWrite(io, Ast.expr.IntNum(t.terminator), Int1Type(false), None, None) + if (!t.consume) { + popPos(io) + if (t.eosError) { + blockScopeFooter + } + } + if (!t.eosError) + condIfIsEofFooter + } + } + } + + def attrBytesLimitWrite2( + io: String, + expr: Ast.expr, + bt: BytesType, + sizeExpr: String, + padRight: Option[Int], + terminator: Option[Int], + include: Boolean, + exprTypeOpt: Option[DataType] + ): Unit = { + val (termArg, padRightArg) = (terminator, padRight, include) match { + case (None, None, false) => + // no terminator, no padding => just a regular output + // validation should check that expression's length matches size + attrPrimitiveWrite(io, expr, bt, None, exprTypeOpt) + return + case (_, None, true) => + // terminator included, no padding => pad with zeroes + (0, 0) + case (_, Some(p), true) => + // terminator included, padding specified + (p, p) + case (Some(t), None, false) => + // only terminator given, don't care about what's gonna go after that + // we'll just pad with zeroes + (t, 0) + case (None, Some(p), false) => + // only padding given, just add terminator equal to padding + (p, p) + case (Some(t), Some(p), false) => + // both terminator and padding specified + (t, p) + } + attrBytesLimitWrite(io, expr, sizeExpr, termArg, padRightArg) + } + + def attrUserTypeWrite( + id: Identifier, + t: UserType, + io: String, + rep: RepeatSpec, + isRaw: Boolean, + defEndian: Option[FixedEndian], + checksShouldDependOnIo: Option[Boolean], + exprTypeOpt: Option[DataType] = None + ) = { + val exprType = exprTypeOpt.getOrElse(t) + val expr = itemExpr(id, rep) + + { + val itemUserType = + if (exprTypeOpt.map(exprType => !exprType.isInstanceOf[UserType]).getOrElse(false)) + Ast.expr.CastToType(expr, Ast.typeId(true, t.classSpec.get.name)) + else + expr + // check non-`io` params + attrUserTypeCheck(id, itemUserType, t, checksShouldDependOnIo) + // set `io` params + (t.classSpec.get.params, t.args).zipped.foreach { (paramDef, argExpr) => + val paramItemType = getArrayItemType(paramDef.dataType) + val paramBasedOnIo = (paramItemType == KaitaiStreamType || paramItemType == OwnedKaitaiStreamType) + if (paramBasedOnIo) + attrSetProperty(itemUserType, paramDef.id, expression(argExpr)) + } + } + + t match { + case _: UserTypeInstream => + attrUserTypeInstreamWrite(io, expr, t, exprType) + case utb: UserTypeFromBytes => + val rawId = RawIdentifier(id) + val byteType = utb.bytes + + /** @note Must be kept in sync with [[ExtraAttrs.writeNeedsOuterSize]] */ + val outerSize = byteType match { + case blt: BytesLimitType => + translator.translate(blt.size) + case _: BytesEosType => + exprIORemainingSize(io) + case _: BytesTerminatedType => + translator.translate(itemExpr(OuterSizeIdentifier(id), rep)) + } + + /** @note Must be kept in sync with [[ExtraAttrs.writeNeedsInnerSize]] */ + val innerSize = if (writeNeedsInnerSize(utb.bytes)) { + translator.translate(itemExpr(InnerSizeIdentifier(id), rep)) + } else { + outerSize + } + + this match { + // case thisStore: AllocateAndStoreIO => + // val ourIO = thisStore.allocateIO(rawId, rep) + // Utils.addUniqueAttr(extraAttrs, AttrSpec(List(), ourIO, KaitaiStreamType)) + // privateMemberName(ourIO) + case thisLocal: AllocateIOLocalVar => + val ioFixed = thisLocal.allocateIOFixed(rawId, innerSize) + addChildIO(io, ioFixed) + + blockScopeHeader + + pushPosForSubIOWriteBackHandler(io) + seekRelative(io, outerSize) + byteType match { + case t: BytesTerminatedType => + if (!t.include && t.consume) { + if (!t.eosError) + condIfIsEofHeader(io, false) + // terminator can only be 1 byte long at the moment + seekRelative(io, expression(Ast.expr.IntNum(1))) + if (!t.eosError) + condIfIsEofFooter + } + case _ => // do nothing + } + + byteType.process.foreach { (process) => + attrUnprocessPrepareBeforeSubIOHandler(process, rawId) + } + + { + val parentIO = subIOWriteBackHeader(ioFixed, byteType.process) + handleAssignment(rawId, exprStreamToByteArray(ioFixed), rep, true) + attrBytesTypeWrite(rawId, byteType, parentIO, rep, isRaw, None, exprTypeOpt) + subIOWriteBackFooter(ioFixed) + } + + blockScopeFooter + + attrUserTypeInstreamWrite(ioFixed, expr, t, exprType) + } + } + } + def attrSwitchTypeWrite( + id: Identifier, + on: Ast.expr, + cases: Map[Ast.expr, DataType], + io: String, + rep: RepeatSpec, + defEndian: Option[FixedEndian], + checksShouldDependOnIoOrig: Option[Boolean], + assignType: DataType + ): Unit = { + val checksShouldDependOnIo = + if (userExprDependsOnIo(on)) { + None + } else { + checksShouldDependOnIoOrig + } + + switchCases[DataType](id, on, cases, + (dataType) => { + attrWrite2(id, dataType, io, rep, false, defEndian, checksShouldDependOnIo, Some(assignType)) + }, + (dataType) => if (switchBytesOnlyAsRaw) { + dataType match { + case t: BytesType => + attrWrite2(RawIdentifier(id), dataType, io, rep, false, defEndian, checksShouldDependOnIo, Some(assignType)) + case _ => + attrWrite2(id, dataType, io, rep, false, defEndian, checksShouldDependOnIo, Some(assignType)) + } + } else { + attrWrite2(id, dataType, io, rep, false, defEndian, checksShouldDependOnIo, Some(assignType)) + } + ) + } + + def internalEnumIntType(basedOn: IntType): DataType + + def attrPrimitiveWrite(io: String, expr: Ast.expr, dt: DataType, defEndian: Option[FixedEndian], exprTypeOpt: Option[DataType]): Unit + def attrBytesLimitWrite(io: String, expr: Ast.expr, size: String, term: Int, padRight: Int): Unit + def attrUserTypeInstreamWrite(io: String, expr: Ast.expr, t: DataType, exprType: DataType): Unit + def exprStreamToByteArray(ioFixed: String): String + + def attrUnprocess(proc: ProcessExpr, varSrc: Identifier, varDest: Identifier, rep: RepeatSpec, dt: BytesType, exprTypeOpt: Option[DataType]): Unit + def attrUnprocessPrepareBeforeSubIOHandler(proc: ProcessExpr, varSrc: Identifier): Unit + + def condIfIsEofHeader(io: String, wantedIsEof: Boolean): Unit + def condIfIsEofFooter: Unit + + def attrSetProperty(base: Ast.expr, propName: Identifier, value: String): Unit +} diff --git a/shared/src/main/scala/io/kaitai/struct/languages/components/ExtraAttrs.scala b/shared/src/main/scala/io/kaitai/struct/languages/components/ExtraAttrs.scala index 2799bb53e..487bfa45e 100644 --- a/shared/src/main/scala/io/kaitai/struct/languages/components/ExtraAttrs.scala +++ b/shared/src/main/scala/io/kaitai/struct/languages/components/ExtraAttrs.scala @@ -3,12 +3,15 @@ package io.kaitai.struct.languages.components import io.kaitai.struct.datatype.DataType import io.kaitai.struct.datatype.DataType._ import io.kaitai.struct.format._ +import io.kaitai.struct.RuntimeConfig /** * Trait to be implemented by all [[LanguageCompiler]] compilers: supplies extra attributes * when we'll be allocating new IOs. */ trait ExtraAttrs { + val config: RuntimeConfig + /** * Provides a collection of extra attributes which will be necessary to store in a class for * handling of a single "normal" attribute. Primarily @@ -29,8 +32,24 @@ trait ExtraAttrs { } case utb: UserTypeFromBytes => // User type in a substream + val dynamicSizeAttributes: List[AttrSpec] = if (config.readWrite) { + val outerSizeOpt: Option[AttrSpec] = if (writeNeedsOuterSize(utb.bytes)) { + Some(AttrSpec(List(), OuterSizeIdentifier(id), CalcIntType, condSpec)) + } else { + None + } + val innerSizeOpt: Option[AttrSpec] = if (writeNeedsInnerSize(utb.bytes)) { + Some(AttrSpec(List(), InnerSizeIdentifier(id), CalcIntType, condSpec)) + } else { + None + } + List(innerSizeOpt, outerSizeOpt).flatten + } else { + List() + } val rawId = RawIdentifier(id) (extraRawAttrForUserTypeFromBytes(id, utb, condSpec) ++ + dynamicSizeAttributes ++ extraAttrForIO(rawId, condSpec.repeat) ++ extraAttrsForAttribute(rawId, utb.bytes, condSpec)).toList.distinct case st: SwitchType => @@ -59,15 +78,39 @@ trait ExtraAttrs { List(AttrSpec(List(), RawIdentifier(id), ut.bytes, condSpec)) def extraAttrForIO(id: Identifier, rep: RepeatSpec): List[AttrSpec] + def writeNeedsOuterSize(bytes: BytesType): Boolean = { + bytes match { + case _: BytesTerminatedType => true + case _ => false + } + } + def writeNeedsInnerSize(bytes: BytesType): Boolean = { + val unknownInnerSizeProcess = bytes.process match { + case Some(process) => process match { + case ProcessZlib | _: ProcessCustom => true + case _: ProcessXor | _: ProcessRotate => false + } + case None => false + } + val unknownInnerSizePadTerm = bytes match { + case bt: BytesLimitType => + bt.padRight.isDefined || bt.terminator.isDefined + case bt: BytesEosType => + bt.padRight.isDefined || bt.terminator.isDefined + case _ => false + } + unknownInnerSizeProcess || unknownInnerSizePadTerm + } } /** * Generates list of extra attributes required to store intermediate / * virtual stuff for every attribute like: * - * * buffered raw value byte arrays - * * IO objects (?) - * * unprocessed / postprocessed byte arrays + * - buffered raw value byte arrays + * - IO objects (?) + * - unprocessed / postprocessed byte arrays + * - outer and inner sizes of fields with substreams */ object ExtraAttrs { def forClassSpec(curClass: ClassSpec, compiler: ExtraAttrs): List[AttrSpec] = { diff --git a/shared/src/main/scala/io/kaitai/struct/languages/components/FetchInstances.scala b/shared/src/main/scala/io/kaitai/struct/languages/components/FetchInstances.scala new file mode 100644 index 000000000..9f99bf99d --- /dev/null +++ b/shared/src/main/scala/io/kaitai/struct/languages/components/FetchInstances.scala @@ -0,0 +1,53 @@ +package io.kaitai.struct.languages.components + +import io.kaitai.struct.exprlang.Ast +import io.kaitai.struct.format._ +import io.kaitai.struct.datatype.DataType +import io.kaitai.struct.datatype.DataType._ + +trait FetchInstances extends LanguageCompiler with ObjectOrientedLanguage with EveryReadIsExpression { + override def attrFetchInstances(attr: AttrLikeSpec, id: Identifier): Unit = { + attrParseIfHeader(id, attr.cond.ifExpr) + + val io = normalIO + + id match { + case instName: InstanceIdentifier => + attrInvokeInstance(instName) + case _ => + } + + if (attr.cond.repeat != NoRepeat) + condRepeatCommonHeader(id, io, attr.dataType) + + attrFetchInstances2(id, attr.dataType, attr.cond.repeat) + + if (attr.cond.repeat != NoRepeat) + condRepeatCommonFooter + + attrParseIfFooter(attr.cond.ifExpr) + } + + def attrFetchInstances2(id: Identifier, dataType: DataType, rep: RepeatSpec, exprTypeOpt: Option[DataType] = None): Unit = { + dataType match { + case _: UserType => + val exprType = exprTypeOpt.getOrElse(dataType) + attrInvokeFetchInstances(itemExpr(id, rep), exprType, dataType) + case st: SwitchType => + attrSwitchTypeFetchInstances(id, st.on, st.cases, rep, st.combinedType) + case _ => + } + } + + def attrSwitchTypeFetchInstances(id: Identifier, on: Ast.expr, cases: Map[Ast.expr, DataType], rep: RepeatSpec, assignType: DataType): Unit = { + switchCases[DataType](id, on, cases, + (dataType) => { + attrFetchInstances2(id, dataType, rep, Some(assignType)) + }, + (dataType) => { + // TODO: process switchBytesOnlyAsRaw + attrFetchInstances2(id, dataType, rep, Some(assignType)) + } + ) + } +} diff --git a/shared/src/main/scala/io/kaitai/struct/languages/components/FixedContentsUsingArrayByteLiteral.scala b/shared/src/main/scala/io/kaitai/struct/languages/components/FixedContentsUsingArrayByteLiteral.scala deleted file mode 100644 index c5d4abd9b..000000000 --- a/shared/src/main/scala/io/kaitai/struct/languages/components/FixedContentsUsingArrayByteLiteral.scala +++ /dev/null @@ -1,21 +0,0 @@ -package io.kaitai.struct.languages.components - -import io.kaitai.struct.exprlang.Ast -import io.kaitai.struct.format.Identifier - -/** - * Allows uniform implementation of attrFixedContentsParse by enforcing usage - * of doByteArrayLiteral in relevant language's translator. - */ -trait FixedContentsUsingArrayByteLiteral extends LanguageCompiler { - def attrFixedContentsParse(attrName: Identifier, contents: Array[Byte]) = - attrFixedContentsParse( - attrName, - translator.translate( - Ast.expr.List( - contents.map(x => Ast.expr.IntNum(BigInt(x & 0xff))) - ) - ) - ) - def attrFixedContentsParse(attrName: Identifier, contents: String): Unit -} diff --git a/shared/src/main/scala/io/kaitai/struct/languages/components/GenericChecks.scala b/shared/src/main/scala/io/kaitai/struct/languages/components/GenericChecks.scala new file mode 100644 index 000000000..449f874cc --- /dev/null +++ b/shared/src/main/scala/io/kaitai/struct/languages/components/GenericChecks.scala @@ -0,0 +1,537 @@ +package io.kaitai.struct.languages.components +import io.kaitai.struct.datatype.DataType +import io.kaitai.struct.datatype.DataType._ +import io.kaitai.struct.exprlang.Ast +import io.kaitai.struct.format._ + +trait GenericChecks extends LanguageCompiler with EveryReadIsExpression { + override def attrCheck(attr: AttrLikeSpec, id: Identifier): Unit = { + val bodyShouldDependOnIo: Option[Boolean] = + if (userExprDependsOnIo(attr.cond.ifExpr)) { + return + } else { + Some(false) + } + + attrParseIfHeader(id, attr.cond.ifExpr) + + val io = normalIO + + attr.cond.repeat match { + case RepeatEos => + condRepeatCommonHeader(id, io, attr.dataType) + attrCheck2(id, attr.dataType, attr.cond.repeat, bodyShouldDependOnIo) + condRepeatCommonFooter + case RepeatExpr(repeatExpr: Ast.expr) => + attrRepeatExprCheck(id, repeatExpr, bodyShouldDependOnIo) + condRepeatCommonHeader(id, io, attr.dataType) + attrCheck2(id, attr.dataType, attr.cond.repeat, bodyShouldDependOnIo) + condRepeatCommonFooter + case repUntil: RepeatUntil => + attrAssertUntilNotEmpty(id) + condRepeatCommonHeader(id, io, attr.dataType) + attrCheck2(id, attr.dataType, attr.cond.repeat, bodyShouldDependOnIo) + attrAssertUntilCond(id, attr.dataType, repUntil, bodyShouldDependOnIo) + condRepeatCommonFooter + case NoRepeat => + attrCheck2(id, attr.dataType, attr.cond.repeat, bodyShouldDependOnIo) + } + // TODO: move to attrCheck2 when we change the `valid` semantics to apply to each item + // in repeated fields, not the entire array (see + // https://github.com/kaitai-io/kaitai_struct_formats/issues/347) + attr.valid.foreach { (valid) => + typeProvider._currentIteratorType = Some(attr.dataTypeComposite) + if (bodyShouldDependOnIo.map(shouldDepend => validDependsOnIo(valid) == shouldDepend).getOrElse(true)) { + attrValidate(id, attr, valid, false) + } + } + + attrParseIfFooter(attr.cond.ifExpr) + } + + + + def userExprDependsOnIo(expr: Option[Ast.expr]): Boolean = expr match { + case None => false + case Some(v) => userExprDependsOnIo(v) + } + + def getArrayItemType(dt: DataType): DataType = { + dt match { + case arr: ArrayType => getArrayItemType(arr.elType) + case other => other + } + } + + def userExprDependsOnIo(expr: Ast.expr): Boolean = { + expr match { + case _: Ast.expr.IntNum => false + case _: Ast.expr.FloatNum => false + case _: Ast.expr.Str => false + case _: Ast.expr.Bool => false + case _: Ast.expr.EnumById => false + case _: Ast.expr.EnumByLabel => false + case n: Ast.expr.Name => + val t = getArrayItemType(translator.detectType(n)) + if (t == KaitaiStreamType || t == OwnedKaitaiStreamType) { + true + } else { + /** @see [[ClassTypeProvider.determineType(inClass:ClassSpec,attrName:String):DataType*]] */ + n.id.name match { + case Identifier.ROOT + | Identifier.PARENT + | Identifier.IO + | Identifier.ITERATOR + | Identifier.SWITCH_ON + | Identifier.INDEX + | Identifier.SIZEOF + => false + case _ => + val spec = typeProvider.resolveMember(typeProvider.nowClass, n.id.name) + spec match { + case _: AttrSpec => false + + // Parameters are fine because they are normally set by the user, so are already + // available in _check(). The only parameters set by the generated code in _write() + // (not by the user) are params of type KaitaiStream or an array of KaitaiStream, + // but these were caught earlier in this function. + case _: ParamDefSpec => false + + // Value instances are OK to use in _check() if their expressions in `value` and + // `if` do not use _io or parse instances. They can refer to other value instances, + // provided they follow the same conditions (which is ensured by a recursive call). + case vis: ValueInstanceSpec => + userExprDependsOnIo(vis.ifExpr) || userExprDependsOnIo(vis.value) + + // Although accessing parse instances in _check() is not a problem by itself, + // because parse instances are set by the user so they are already available in + // _check(), it becomes a problem when you don't invoke a parse instance dependent + // on the time of invocation in _write() because you have already done a particular + // check in _check(). + // + // Take the test + // https://github.com/kaitai-io/kaitai_struct_tests/blob/010efd1d9c07a61a320a644d4e782dd488ba28e4/formats/instance_in_repeat_until.ksy + // as an example. In _write() you don't need to reproduce the special `do { ... } + // while (!repeat-until);` loop as used in _read(), because you already know the + // array length, so a simple "foreach" loop will suffice. Then there is a + // consistency check to ensure that the `repeat-until` condition is `false` for all + // items except the last one, and `true` for the last one. This check can be either + // done in _check(), or in _write() at the end of each iteration of the "foreach" + // loop. You can do it in _check() if you want, but you *need* to evaluate the + // `repeat-until` expression (and throw away the result, if you like - the point is + // just to invoke the parse instances specified there) at the end of each "foreach" + // loop iteration in _write(), because _read() does that. So it makes sense to do + // the check only in _write(). + // + // It may be tempting to suggest to do the check both in _check() and _write(), and + // in this particular case you could really do that because the parse instance is + // used directly in the `repeat-until` expression. But if such parse instance is used + // indirectly via a value instance, you should no longer use that value instance in + // _check() at all, because that would cache the its value and the invocation in + // _write() would merely return this cached value, not evaluating the expression + // again. But that means that the parse instance will be written at a different + // time, because it won't be invoked from `seq` at the time it would be in _read() + // and will be written only when invoked from _fetchInstances(), which is wrong and + // inconsistent with parsing. Although the user could work around this specific + // issue by manually invalidating the value instances that the careless _check() + // invoked after calling it, this would be a bug in _check(). Calling _check() + // should not have side effects that the user has to "undo". + // + // Of course, perhaps most parse instances are not dependent on the time of + // invocation. But the language allows them to be, and it's not that trivial to + // detect it: you have to analyze all expressions that affect its parsing. So we + // will not do that for now - it's easier to avoid invoking parse instances in + // _check() (directly or indirectly) entirely. + case _: ParseInstanceSpec => true + } + } + } + case Ast.expr.InternalName(id) => + ??? + case Ast.expr.UnaryOp(op, inner) => + userExprDependsOnIo(inner) + case Ast.expr.Compare(left, op, right) => + userExprDependsOnIo(left) || userExprDependsOnIo(right) + case Ast.expr.BinOp(left, op, right) => + userExprDependsOnIo(left) || userExprDependsOnIo(right) + case Ast.expr.BoolOp(op, values) => + values.exists(v => userExprDependsOnIo(v)) + case Ast.expr.IfExp(condition, ifTrue, ifFalse) => + userExprDependsOnIo(condition) || userExprDependsOnIo(ifTrue) || userExprDependsOnIo(ifFalse) + case Ast.expr.Subscript(value, idx) => + userExprDependsOnIo(value) || userExprDependsOnIo(idx) + case a: Ast.expr.Attribute => + val t = getArrayItemType(translator.detectType(a)) + if (t == KaitaiStreamType || t == OwnedKaitaiStreamType) { + true + } else { + userExprDependsOnIo(a.value) + } + case Ast.expr.Call(func, args) => + (func match { + case Ast.expr.Attribute(value, methodName) => + userExprDependsOnIo(value) + }) || args.exists(v => userExprDependsOnIo(v)) + case Ast.expr.List(values: Seq[Ast.expr]) => + values.exists(v => userExprDependsOnIo(v)) + case Ast.expr.CastToType(value, typeName) => + userExprDependsOnIo(value) + case _: Ast.expr.ByteSizeOfType => false + case _: Ast.expr.BitSizeOfType => false + } + } + + def validDependsOnIo(valid: ValidationSpec): Boolean = { + valid match { + case ValidationEq(expected) => userExprDependsOnIo(expected) + case ValidationMin(min) => userExprDependsOnIo(min) + case ValidationMax(max) => userExprDependsOnIo(max) + case ValidationRange(min, max) => userExprDependsOnIo(min) || userExprDependsOnIo(max) + case ValidationAnyOf(values) => values.exists(v => userExprDependsOnIo(v)) + case ValidationExpr(expr) => userExprDependsOnIo(expr) + } + } + + def attrCheck2(id: Identifier, dataType: DataType, repeat: RepeatSpec, shouldDependOnIo: Option[Boolean], exprTypeOpt: Option[DataType] = None) = { + val item = itemExpr(id, repeat) + dataType match { + case ut: UserType => + val itemUserType = + if (exprTypeOpt.map(exprType => !exprType.isInstanceOf[UserType]).getOrElse(false)) + Ast.expr.CastToType(item, Ast.typeId(true, ut.classSpec.get.name)) + else + item + attrUserTypeCheck(id, itemUserType, ut, shouldDependOnIo) + case t: BytesType => + val itemBytes = + if (exprTypeOpt.map(exprType => !exprType.isInstanceOf[BytesType]).getOrElse(false)) + Ast.expr.CastToType(item, Ast.typeId(false, Seq("bytes"))) + else + item + attrBytesCheck(id, itemBytes, t, shouldDependOnIo) + case st: StrFromBytesType => + val itemStr = + if (exprTypeOpt.map(exprType => !exprType.isInstanceOf[StrType]).getOrElse(false)) + Ast.expr.CastToType(item, Ast.typeId(false, Seq("str"))) + else + item + val bytes = exprStrToBytes(itemStr, st.encoding) + attrBytesCheck(id, bytes, st.bytes, shouldDependOnIo) + case st: SwitchType => + attrSwitchCheck(id, st.on, st.cases, repeat, shouldDependOnIo, st.combinedType) + case _ => // no checks + } + } + + def attrRepeatExprCheck(id: Identifier, expectedSize: Ast.expr, shouldDependOnIo: Option[Boolean]): Unit = { + if (shouldDependOnIo.map(shouldDepend => userExprDependsOnIo(expectedSize) != shouldDepend).getOrElse(false)) + return + attrAssertEqual( + exprArraySize(Ast.expr.InternalName(id)), + expectedSize, + idToMsg(id) + ) + } + + def attrBytesCheck(id: Identifier, bytes: Ast.expr, t: BytesType, shouldDependOnIoOrig: Option[Boolean]): Unit = { + val shouldDependOnIo: Option[Boolean] = + if (t.process.isDefined) { + if (shouldDependOnIoOrig.getOrElse(true)) { + None + } else { + return + } + } else { + shouldDependOnIoOrig + } + + val msgId = idToMsg(id) + val actualSize = exprByteArraySize(bytes) + val canUseNonIoDependent = shouldDependOnIo.map(shouldDepend => shouldDepend == false).getOrElse(true) + t match { + case blt: BytesLimitType => { + val limitSize = blt.size + val canUseLimitSize = shouldDependOnIo.map(shouldDepend => userExprDependsOnIo(limitSize) == shouldDepend).getOrElse(true) + if (canUseLimitSize) { + if (blt.terminator.isDefined || blt.padRight.isDefined) { + // size must be "<= declared" (less than or equal to declared size) + attrAssertLtE(actualSize, limitSize, msgId) + } else { + // size must match declared size exactly + attrAssertEqual( + actualSize, + limitSize, + msgId + ) + } + } + blt.terminator match { + case Some(term) => { + val actualIndexOfTerm = exprByteArrayIndexOf(bytes, term) + val isPadRightActive = blt.padRight.map(padByte => padByte != term).getOrElse(false) + if (!blt.include) { + if (canUseNonIoDependent) { + attrAssertEqual(actualIndexOfTerm, Ast.expr.IntNum(-1), msgId) + } + if (isPadRightActive && canUseLimitSize) { + condIfHeader(Ast.expr.Compare(actualSize, Ast.cmpop.Eq, limitSize)) + // check if the last byte is not `pad-right` + attrBytesPadRightCheck(bytes, actualSize, blt.padRight, msgId) + condIfFooter + } + } else { + val lastByteIndex = Ast.expr.BinOp(actualSize, Ast.operator.Sub, Ast.expr.IntNum(1)) + if (!isPadRightActive && canUseLimitSize) { + condIfHeader(Ast.expr.Compare(actualSize, Ast.cmpop.Lt, limitSize)) + // must not be empty (always contains at least the `terminator` byte) + attrAssertCmp(actualSize, Ast.cmpop.Eq, Ast.expr.IntNum(0), msgId) + // the user wants to terminate the value prematurely and there's no `pad-right` that + // could do that, so the last byte of the value must be `terminator` + attrAssertEqual(actualIndexOfTerm, lastByteIndex, msgId) + condIfFooter + + condIfHeader(Ast.expr.Compare(actualSize, Ast.cmpop.Eq, limitSize)) + attrTermIncludeCheck(actualIndexOfTerm, lastByteIndex, msgId) + condIfFooter + } + if (isPadRightActive && canUseNonIoDependent) { + attrTermIncludeCheck(actualIndexOfTerm, lastByteIndex, msgId) + + condIfHeader(Ast.expr.Compare(actualIndexOfTerm, Ast.cmpop.Eq, Ast.expr.IntNum(-1))) + // check if the last byte is not `pad-right` + attrBytesPadRightCheck(bytes, actualSize, blt.padRight, msgId) + condIfFooter + } + } + } + case None => + if (canUseLimitSize) { + // check if the last byte is not `pad-right` + attrBytesPadRightCheck(bytes, actualSize, blt.padRight, msgId) + } + } + } + case btt: BytesTerminatedType => { + if (canUseNonIoDependent) { + val actualIndexOfTerm = exprByteArrayIndexOf(bytes, btt.terminator) + val lastByteIndex: Ast.expr = Ast.expr.BinOp(actualSize, Ast.operator.Sub, Ast.expr.IntNum(1)) + val expectedIndexOfTerm = if (btt.include) { + if (btt.eosError) { + // must not be empty (always contains at least the `terminator` byte) + attrAssertCmp(actualSize, Ast.cmpop.Eq, Ast.expr.IntNum(0), msgId) + + attrAssertEqual(actualIndexOfTerm, lastByteIndex, msgId) + } else { + attrTermIncludeCheck(actualIndexOfTerm, lastByteIndex, msgId) + } + } else { + attrAssertEqual(actualIndexOfTerm, Ast.expr.IntNum(-1), msgId) + } + } + } + case _ => // no checks + } + } + + def attrBytesPadRightCheck(bytes: Ast.expr, actualSize: Ast.expr, padRight: Option[Int], msgId: String): Unit = + padRight.foreach { (padByte) => + val lastByte = exprByteArrayLast(bytes) + attrBasicCheck( + Ast.expr.BoolOp( + Ast.boolop.And, + Seq( + Ast.expr.Compare(actualSize, Ast.cmpop.NotEq, Ast.expr.IntNum(0)), + Ast.expr.Compare( + lastByte, + Ast.cmpop.Eq, + Ast.expr.IntNum(padByte) + ) + ) + ), + lastByte, + Ast.expr.IntNum(padByte), + msgId + ) + } + + def attrTermIncludeCheck(actualIndexOfTerm: Ast.expr, lastByteIndex: Ast.expr, msgId: String): Unit = + attrBasicCheck( + Ast.expr.BoolOp( + Ast.boolop.And, + Seq( + Ast.expr.Compare(actualIndexOfTerm, Ast.cmpop.NotEq, Ast.expr.IntNum(-1)), + Ast.expr.Compare(actualIndexOfTerm, Ast.cmpop.NotEq, lastByteIndex) + ) + ), + actualIndexOfTerm, + lastByteIndex, + msgId + ) + + def attrUserTypeCheck(id: Identifier, utExpr: Ast.expr, ut: UserType, shouldDependOnIo: Option[Boolean]): Unit = { + /** @note Must be kept in sync with [[JavaCompiler.parseExpr]] */ + if (!ut.isOpaque) { + attrUserTypeParamCheck(id, ut, utExpr, RootIdentifier, CalcKaitaiStructType(), Ast.expr.Name(Ast.identifier(Identifier.ROOT)), shouldDependOnIo) + } + attrParentParamCheck(id, Ast.expr.Attribute(utExpr, Ast.identifier(Identifier.PARENT)), ut, shouldDependOnIo) + (ut.classSpec.get.params, ut.args).zipped.foreach { (paramDef, argExpr) => + attrUserTypeParamCheck(id, ut, utExpr, paramDef.id, paramDef.dataType, argExpr, shouldDependOnIo) + } + } + + def attrUserTypeParamCheck(id: Identifier, ut: UserType, utExpr: Ast.expr, paramId: Identifier, paramDataType: DataType, argExpr: Ast.expr, shouldDependOnIo: Option[Boolean]): Unit = { + val paramItemType = getArrayItemType(paramDataType) + val paramBasedOnIo = (paramItemType == KaitaiStreamType || paramItemType == OwnedKaitaiStreamType) + // parameters with types `io` or `io[]` never have to be checked for consistency because they're set by the generated code + if (paramBasedOnIo) + return + if (shouldDependOnIo.map(shouldDepend => userExprDependsOnIo(argExpr) != shouldDepend).getOrElse(false)) + return + val paramAttrName = paramId match { + case NamedIdentifier(name) => name + case SpecialIdentifier(name) => name + } + val actualArgExpr = Ast.expr.Attribute(utExpr, Ast.identifier(paramAttrName)) + val msgId = idToMsg(id) + paramDataType match { + /** @note Must be kept in sync with [[translators.BaseTranslator.translate]] */ + case _: NumericType | _: BooleanType | _: StrType | _: BytesType | _: EnumType => + attrAssertEqual(actualArgExpr, argExpr, msgId) + case _: ArrayType => + attrObjectsEqualCheck(actualArgExpr, argExpr, msgId) + case _: StructType => + attrObjectsEqualCheck(actualArgExpr, argExpr, msgId) + case AnyType => + attrObjectsEqualCheck(actualArgExpr, argExpr, msgId) + } + } + + def attrSwitchCheck( + id: Identifier, + on: Ast.expr, + cases: Map[Ast.expr, DataType], + rep: RepeatSpec, + shouldDependOnIoOrig: Option[Boolean], + assignType: DataType + ): Unit = { + val shouldDependOnIo: Option[Boolean] = + if (userExprDependsOnIo(on)) { + if (shouldDependOnIoOrig.getOrElse(true)) { + None + } else { + return + } + } else { + shouldDependOnIoOrig + } + + switchCases[DataType](id, on, cases, + (dataType) => { + attrCheck2(id, dataType, rep, shouldDependOnIo, Some(assignType)) + }, + (dataType) => if (switchBytesOnlyAsRaw) { + dataType match { + case t: BytesType => + attrCheck2(RawIdentifier(id), dataType, rep, shouldDependOnIo, Some(assignType)) + case _ => + attrCheck2(id, dataType, rep, shouldDependOnIo, Some(assignType)) + } + } else { + attrCheck2(id, dataType, rep, shouldDependOnIo, Some(assignType)) + } + ) + } + + def attrBasicCheck(checkExpr: Ast.expr, actual: Ast.expr, expected: Ast.expr, msg: String): Unit + + // This may turn out to be too Java-specific method, so we can refactor it later. + def attrObjectsEqualCheck(actual: Ast.expr, expected: Ast.expr, msg: String): Unit + + private + def idToMsg(id: Identifier): String = id.humanReadable + + def exprByteArraySize(name: Ast.expr) = + Ast.expr.Attribute( + name, + Ast.identifier("size") + ) + + def exprByteArrayLast(name: Ast.expr) = + Ast.expr.Attribute( + name, + Ast.identifier("last") + ) + + def exprByteArrayIndexOf(name: Ast.expr, term: Int) = + Ast.expr.Call( + Ast.expr.Attribute( + name, + Ast.identifier("index_of") + ), + Seq(Ast.expr.IntNum(term)) + ) + + def exprStrToBytes(name: Ast.expr, encoding: String) = + Ast.expr.Call( + Ast.expr.Attribute( + name, + Ast.identifier("to_b") + ), + Seq(Ast.expr.Str(encoding)) + ) + + def exprArraySize(name: Ast.expr) = exprByteArraySize(name) + + def exprArrayLast(name: Ast.expr) = exprByteArrayLast(name) + + def attrAssertUntilNotEmpty(id: Identifier): Unit = { + // the array must not be empty (always contains at least the `repeat-until: {true}` element) + attrAssertCmp(exprArraySize(Ast.expr.InternalName(id)), Ast.cmpop.Eq, Ast.expr.IntNum(0), idToMsg(id)) + } + + def attrAssertUntilCond(id: Identifier, dataType: DataType, repUntil: RepeatUntil, shouldDependOnIo: Option[Boolean]): Unit = { + typeProvider._currentIteratorType = Some(dataType) + if (shouldDependOnIo.map(shouldDepend => userExprDependsOnIo(repUntil.expr) != shouldDepend).getOrElse(false)) + return + blockScopeHeader + handleAssignmentTempVar( + dataType, + translator.doName(Identifier.ITERATOR), + translator.translate(itemExpr(id, repUntil)) + ) + attrAssertEqual( + repUntil.expr, + Ast.expr.Compare( + Ast.expr.Name(Ast.identifier(Identifier.INDEX)), + Ast.cmpop.Eq, + Ast.expr.BinOp(exprArraySize(Ast.expr.InternalName(id)), Ast.operator.Sub, Ast.expr.IntNum(1)) + ), + idToMsg(id) + ) + blockScopeFooter + } + + def attrIsEofCheck(id: Identifier, expectedIsEof: Boolean, io: String): Unit = + attrIsEofCheck(io, expectedIsEof, idToMsg(id)) + + def attrIsEofCheck(io: String, expectedIsEof: Boolean, msg: String): Unit + + def attrParentParamCheck(id: Identifier, actualParentExpr: Ast.expr, ut: UserType, shouldDependOnIo: Option[Boolean]): Unit = + attrParentParamCheck(actualParentExpr, ut, shouldDependOnIo, idToMsg(id)) + + def attrParentParamCheck(actualParentExpr: Ast.expr, ut: UserType, shouldDependOnIo: Option[Boolean], msg: String): Unit + + def attrAssertEqual(actual: Ast.expr, expected: Ast.expr, msg: String): Unit = + attrAssertCmp(actual, Ast.cmpop.NotEq, expected, msg) + + def attrAssertLtE(actual: Ast.expr, expected: Ast.expr, msg: String): Unit = + attrAssertCmp(actual, Ast.cmpop.Gt, expected, msg) + + def attrAssertCmp(actual: Ast.expr, op: Ast.cmpop, expected: Ast.expr, msg: String): Unit = + attrBasicCheck( + Ast.expr.Compare(actual, op, expected), + actual, + expected, + msg + ) +} diff --git a/shared/src/main/scala/io/kaitai/struct/languages/components/LanguageCompiler.scala b/shared/src/main/scala/io/kaitai/struct/languages/components/LanguageCompiler.scala index 76fc21a3c..18368ad88 100644 --- a/shared/src/main/scala/io/kaitai/struct/languages/components/LanguageCompiler.scala +++ b/shared/src/main/scala/io/kaitai/struct/languages/components/LanguageCompiler.scala @@ -116,6 +116,7 @@ abstract class LanguageCompiler( def attributeDeclaration(attrName: Identifier, attrType: DataType, isNullable: Boolean): Unit def attributeReader(attrName: Identifier, attrType: DataType, isNullable: Boolean): Unit + def attributeSetter(attrName: Identifier, attrType: DataType, isNullable: Boolean): Unit = ??? def attributeDoc(id: Identifier, doc: DocSpec): Unit = {} def attrParse(attr: AttrLikeSpec, id: Identifier, defEndian: Option[Endianness]): Unit @@ -123,13 +124,29 @@ abstract class LanguageCompiler( def attrInit(attr: AttrLikeSpec): Unit = {} def attrDestructor(attr: AttrLikeSpec, id: Identifier): Unit = {} - // TODO: delete - def attrFixedContentsParse(attrName: Identifier, contents: Array[Byte]): Unit + def attrFetchInstances(attr: AttrLikeSpec, id: Identifier): Unit = {} + def fetchInstancesHeader(): Unit = {} + def fetchInstancesFooter(): Unit = {} + def attrInvokeFetchInstances(baseExpr: Ast.expr, exprType: DataType, dataType: DataType): Unit = ??? + def attrInvokeInstance(instName: InstanceIdentifier): Unit = ??? + + def writeHeader(endian: Option[FixedEndian], isEmpty: Boolean): Unit = ??? + def writeFooter(): Unit = ??? + def writeInstanceHeader(instName: InstanceIdentifier): Unit = ??? + def writeInstanceFooter(): Unit = ??? + def attrWrite(attr: AttrLikeSpec, id: Identifier, defEndian: Option[Endianness]): Unit = ??? + def runWriteCalc(): Unit = ??? + + def checkHeader(): Unit = ??? + def checkFooter(): Unit = ??? + def checkInstanceHeader(instName: InstanceIdentifier): Unit = ??? + def checkInstanceFooter(): Unit = ??? + def attrCheck(attr: AttrLikeSpec, id: Identifier): Unit = ??? def condIfSetNull(instName: Identifier): Unit = {} def condIfSetNonNull(instName: Identifier): Unit = {} def condIfHeader(expr: Ast.expr): Unit - def condIfFooter(expr: Ast.expr): Unit + def condIfFooter: Unit def condRepeatInitAttr(id: Identifier, dataType: DataType): Unit @@ -142,24 +159,43 @@ abstract class LanguageCompiler( def condRepeatUntilHeader(id: Identifier, io: String, dataType: DataType, untilExpr: Ast.expr): Unit def condRepeatUntilFooter(id: Identifier, io: String, dataType: DataType, untilExpr: Ast.expr): Unit + def condRepeatCommonHeader(id: Identifier, io: String, dataType: DataType): Unit = {} + def condRepeatCommonFooter: Unit = {} + def attrProcess(proc: ProcessExpr, varSrc: Identifier, varDest: Identifier, rep: RepeatSpec): Unit def normalIO: String def useIO(ioEx: Ast.expr): String def pushPos(io: String): Unit + def pushPosForSubIOWriteBackHandler(io: String): Unit = ??? def seek(io: String, pos: Ast.expr): Unit + def seekRelative(io: String, relPos: String): Unit = ??? def popPos(io: String): Unit def alignToByte(io: String): Unit + def exprIORemainingSize(io: String): String = ??? + + def subIOWriteBackHeader(subIO: String, process: Option[ProcessExpr]): String = ??? + def subIOWriteBackFooter(subIO: String): Unit = ??? + + def addChildIO(io: String, childIO: String): Unit = ??? + def instanceDeclHeader(className: List[String]): Unit = {} def instanceClear(instName: InstanceIdentifier): Unit = {} def instanceSetCalculated(instName: InstanceIdentifier): Unit = {} def instanceDeclaration(attrName: InstanceIdentifier, attrType: DataType, isNullable: Boolean): Unit = attributeDeclaration(attrName, attrType, isNullable) + def instanceWriteFlagDeclaration(attrName: InstanceIdentifier): Unit = ??? + def instanceWriteFlagInit(attrName: InstanceIdentifier): Unit = {} + def instanceSetWriteFlag(instName: InstanceIdentifier): Unit = ??? + def instanceClearWriteFlag(instName: InstanceIdentifier): Unit = ??? + def instanceToWriteSetter(instName: InstanceIdentifier): Unit = ??? def instanceHeader(className: List[String], instName: InstanceIdentifier, dataType: DataType, isNullable: Boolean): Unit def instanceFooter: Unit def instanceCheckCacheAndReturn(instName: InstanceIdentifier, dataType: DataType): Unit def instanceReturn(instName: InstanceIdentifier, attrType: DataType): Unit def instanceCalculate(instName: Identifier, dataType: DataType, value: Ast.expr) + def instanceInvalidate(instName: InstanceIdentifier): Unit = ??? + def instanceCheckWriteFlagAndWrite(instName: InstanceIdentifier): Unit = ??? def enumDeclaration(curClass: List[String], enumName: String, enumColl: Seq[(Long, EnumValueSpec)]): Unit @@ -192,7 +228,7 @@ abstract class LanguageCompiler( def attrParseIfFooter(ifExpr: Option[Ast.expr]): Unit = { ifExpr match { - case Some(e) => condIfFooter(e) + case Some(e) => condIfFooter case None => // ignore } } diff --git a/shared/src/main/scala/io/kaitai/struct/languages/components/UniversalFooter.scala b/shared/src/main/scala/io/kaitai/struct/languages/components/UniversalFooter.scala index bbfd392b8..2beb55b67 100644 --- a/shared/src/main/scala/io/kaitai/struct/languages/components/UniversalFooter.scala +++ b/shared/src/main/scala/io/kaitai/struct/languages/components/UniversalFooter.scala @@ -1,12 +1,13 @@ package io.kaitai.struct.languages.components import io.kaitai.struct.exprlang.Ast.expr +import io.kaitai.struct.format.ClassSpec /** * All footers in the language look the same and can be written by the same * simple argument-less method. */ -trait UniversalFooter { +trait UniversalFooter extends LanguageCompiler { /** * Single method that outputs all kind of footers in the language. */ @@ -14,8 +15,14 @@ trait UniversalFooter { def classFooter(name: String): Unit = universalFooter def classConstructorFooter: Unit = universalFooter + override def readFooter: Unit = universalFooter + override def writeFooter: Unit = universalFooter + override def writeInstanceFooter: Unit = universalFooter + override def checkFooter: Unit = universalFooter + override def checkInstanceFooter: Unit = universalFooter def condRepeatExprFooter = universalFooter def condRepeatEosFooter: Unit = universalFooter - def condIfFooter(expr: expr): Unit = universalFooter + override def condRepeatCommonFooter: Unit = universalFooter + def condIfFooter: Unit = universalFooter def instanceFooter: Unit = universalFooter } diff --git a/shared/src/main/scala/io/kaitai/struct/languages/components/ValidateOps.scala b/shared/src/main/scala/io/kaitai/struct/languages/components/ValidateOps.scala index 39290579c..3006a6e42 100644 --- a/shared/src/main/scala/io/kaitai/struct/languages/components/ValidateOps.scala +++ b/shared/src/main/scala/io/kaitai/struct/languages/components/ValidateOps.scala @@ -13,17 +13,17 @@ trait ValidateOps extends ExceptionNames { val translator: AbstractTranslator val typeProvider: ClassTypeProvider - def attrValidate(attrId: Identifier, attr: AttrLikeSpec, valid: ValidationSpec): Unit = { + def attrValidate(attrId: Identifier, attr: AttrLikeSpec, valid: ValidationSpec, useIo: Boolean): Unit = { valid match { case ValidationEq(expected) => - attrValidateExprCompare(attrId, attr, Ast.cmpop.Eq, expected, ValidationNotEqualError(attr.dataTypeComposite)) + attrValidateExprCompare(attrId, attr, Ast.cmpop.Eq, expected, ValidationNotEqualError(attr.dataTypeComposite), useIo) case ValidationMin(min) => - attrValidateExprCompare(attrId, attr, Ast.cmpop.GtE, min, ValidationLessThanError(attr.dataTypeComposite)) + attrValidateExprCompare(attrId, attr, Ast.cmpop.GtE, min, ValidationLessThanError(attr.dataTypeComposite), useIo) case ValidationMax(max) => - attrValidateExprCompare(attrId, attr, Ast.cmpop.LtE, max, ValidationGreaterThanError(attr.dataTypeComposite)) + attrValidateExprCompare(attrId, attr, Ast.cmpop.LtE, max, ValidationGreaterThanError(attr.dataTypeComposite), useIo) case ValidationRange(min, max) => - attrValidateExprCompare(attrId, attr, Ast.cmpop.GtE, min, ValidationLessThanError(attr.dataTypeComposite)) - attrValidateExprCompare(attrId, attr, Ast.cmpop.LtE, max, ValidationGreaterThanError(attr.dataTypeComposite)) + attrValidateExprCompare(attrId, attr, Ast.cmpop.GtE, min, ValidationLessThanError(attr.dataTypeComposite), useIo) + attrValidateExprCompare(attrId, attr, Ast.cmpop.LtE, max, ValidationGreaterThanError(attr.dataTypeComposite), useIo) case ValidationAnyOf(values) => val bigOrExpr = Ast.expr.BoolOp( Ast.boolop.Or, @@ -37,15 +37,10 @@ trait ValidateOps extends ExceptionNames { ) attrValidateExpr( - attrId, - attr.dataTypeComposite, + attr, checkExpr = bigOrExpr, err = ValidationNotAnyOfError(attr.dataTypeComposite), - errArgs = List( - Ast.expr.InternalName(attrId), - Ast.expr.InternalName(IoIdentifier), - Ast.expr.Str(attr.path.mkString("/", "/", "")) - ) + useIo ) case ValidationExpr(expr) => blockScopeHeader @@ -56,40 +51,37 @@ trait ValidateOps extends ExceptionNames { translator.translate(Ast.expr.InternalName(attrId)) ) attrValidateExpr( - attrId, - attr.dataTypeComposite, + attr, expr, ValidationExprError(attr.dataTypeComposite), - List( - Ast.expr.InternalName(attrId), - Ast.expr.InternalName(IoIdentifier), - Ast.expr.Str(attr.path.mkString("/", "/", "")) - ) + useIo ) blockScopeFooter } } - def attrValidateExprCompare(attrId: Identifier, attr: AttrLikeSpec, op: Ast.cmpop, expected: Ast.expr, err: KSError): Unit = { + def attrValidateExprCompare( + attrId: Identifier, + attr: AttrLikeSpec, + op: Ast.cmpop, + expected: Ast.expr, + err: KSError, + useIo: Boolean + ): Unit = { attrValidateExpr( - attrId, - attr.dataTypeComposite, + attr, checkExpr = Ast.expr.Compare( Ast.expr.InternalName(attrId), op, expected ), err = err, - errArgs = List( - expected, - Ast.expr.InternalName(attrId), - Ast.expr.InternalName(IoIdentifier), - Ast.expr.Str(attr.path.mkString("/", "/", "")) - ) + useIo = useIo, + expected = Some(expected) ) } - def attrValidateExpr(attrId: Identifier, attrType: DataType, checkExpr: Ast.expr, err: KSError, errArgs: List[Ast.expr]): Unit = {} + def attrValidateExpr(attr: AttrLikeSpec, checkExpr: Ast.expr, err: KSError, useIo: Boolean, expected: Option[Ast.expr] = None): Unit = {} def handleAssignmentTempVar(dataType: DataType, id: String, expr: String): Unit def blockScopeHeader: Unit def blockScopeFooter: Unit diff --git a/shared/src/main/scala/io/kaitai/struct/translators/CommonMethods.scala b/shared/src/main/scala/io/kaitai/struct/translators/CommonMethods.scala index 6312a67d9..9e3af9fed 100644 --- a/shared/src/main/scala/io/kaitai/struct/translators/CommonMethods.scala +++ b/shared/src/main/scala/io/kaitai/struct/translators/CommonMethods.scala @@ -97,6 +97,7 @@ abstract trait CommonMethods[T] extends TypeDetector { // TODO: check argument quantity case (_: StrType, "substring") => strSubstring(obj, args(0), args(1)) case (_: StrType, "to_i") => strToInt(obj, args(0)) + case (_: StrType, "to_b") => strToBytes(obj, args(0)) case (_: BytesType, "to_s") => args match { case Seq(Ast.expr.Str(encoding)) => @@ -106,6 +107,7 @@ abstract trait CommonMethods[T] extends TypeDetector { case _ => throw new TypeMismatchError(s"to_s: expected 1 argument, got ${args.length}") } + case (_: BytesType, "index_of") => bytesIndexOf(obj, args(0)) case _ => throw new TypeMismatchError(s"don't know how to call method '$methodName' of object type '$objType'") } } @@ -125,8 +127,10 @@ abstract trait CommonMethods[T] extends TypeDetector { def strReverse(s: Ast.expr): T def strToInt(s: Ast.expr, base: Ast.expr): T def strSubstring(s: Ast.expr, from: Ast.expr, to: Ast.expr): T + def strToBytes(s: Ast.expr, encoding: Ast.expr): T = ??? def bytesToStr(value: Ast.expr, encoding: String): T + def bytesIndexOf(value: Ast.expr, expr: Ast.expr): T = ??? def intToStr(value: Ast.expr, num: Ast.expr): T diff --git a/shared/src/main/scala/io/kaitai/struct/translators/JavaTranslator.scala b/shared/src/main/scala/io/kaitai/struct/translators/JavaTranslator.scala index 5130a069f..8ccaf6781 100644 --- a/shared/src/main/scala/io/kaitai/struct/translators/JavaTranslator.scala +++ b/shared/src/main/scala/io/kaitai/struct/translators/JavaTranslator.scala @@ -1,6 +1,6 @@ package io.kaitai.struct.translators -import io.kaitai.struct.{ImportList, Utils} +import io.kaitai.struct.{ClassTypeProvider, ImportList, RuntimeConfig, Utils} import io.kaitai.struct.exprlang.Ast import io.kaitai.struct.exprlang.Ast._ import io.kaitai.struct.datatype.DataType @@ -8,7 +8,7 @@ import io.kaitai.struct.datatype.DataType._ import io.kaitai.struct.format.Identifier import io.kaitai.struct.languages.JavaCompiler -class JavaTranslator(provider: TypeProvider, importList: ImportList) extends BaseTranslator(provider) { +class JavaTranslator(provider: TypeProvider, importList: ImportList, config: RuntimeConfig) extends BaseTranslator(provider) { override def doIntLiteral(n: BigInt): String = { // Java's integer parsing behaves differently depending on whether you use decimal or hex syntax. // With decimal syntax, the parser/compiler rejects any number that cannot be stored in a long @@ -32,7 +32,10 @@ class JavaTranslator(provider: TypeProvider, importList: ImportList) extends Bas } override def doArrayLiteral(t: DataType, value: Seq[expr]): String = { - val javaType = JavaCompiler.kaitaiType2JavaTypeBoxed(t) + // FIXME + val compiler = new JavaCompiler(provider.asInstanceOf[ClassTypeProvider], config) + + val javaType = compiler.kaitaiType2JavaTypeBoxed(t) val commaStr = value.map((v) => translate(v)).mkString(", ") importList.add("java.util.ArrayList") @@ -54,6 +57,12 @@ class JavaTranslator(provider: TypeProvider, importList: ImportList) extends Bas } } + override def doNumericCompareOp(left: expr, op: cmpop, right: expr): String = + s"(${super.doNumericCompareOp(left, op, right)})" + + override def doEnumCompareOp(left: expr, op: cmpop, right: expr): String = + s"(${super.doEnumCompareOp(left, op, right)})" + override def doName(s: String) = s match { case Identifier.ITERATOR => "_it" @@ -99,11 +108,19 @@ class JavaTranslator(provider: TypeProvider, importList: ImportList) extends Bas } override def arraySubscript(container: expr, idx: expr): String = - s"${translate(container)}.get((int) ${translate(idx)})" + s"${translate(container)}.get(${doCast(idx, CalcIntType)})" override def doIfExp(condition: expr, ifTrue: expr, ifFalse: expr): String = s"(${translate(condition)} ? ${translate(ifTrue)} : ${translate(ifFalse)})" - override def doCast(value: Ast.expr, typeName: DataType): String = - s"((${JavaCompiler.kaitaiType2JavaType(typeName)}) (${translate(value)}))" + override def doCast(value: Ast.expr, typeName: DataType): String = { + // FIXME + val compiler = new JavaCompiler(provider.asInstanceOf[ClassTypeProvider], config) + if (value.isInstanceOf[Ast.expr.IntNum] || value.isInstanceOf[Ast.expr.FloatNum]) + // this branch is not really needed, but makes the code a bit cleaner - + // we can simplify casting to just this for numeric constants + s"((${compiler.kaitaiType2JavaType(typeName)}) ${translate(value)})" + else + compiler.castIfNeeded(translate(value), AnyType, typeName) + } // Predefined methods of various types override def strToInt(s: expr, base: expr): String = @@ -111,7 +128,7 @@ class JavaTranslator(provider: TypeProvider, importList: ImportList) extends Bas override def enumToInt(v: expr, et: EnumType): String = s"${translate(v)}.id()" override def floatToInt(v: expr): String = - s"(int) (${translate(v)} + 0)" + doCast(v, CalcIntType) override def intToStr(i: expr, base: expr): String = s"Long.toString(${translate(i)}, ${translate(base)})" override def bytesToStr(bytesExpr: String, encoding: String): String = { @@ -136,26 +153,39 @@ class JavaTranslator(provider: TypeProvider, importList: ImportList) extends Bas } s"new String($bytesExpr, $charsetExpr)" } + override def bytesIndexOf(b: expr, byte: expr): String = + s"${JavaCompiler.kstreamName}.byteArrayIndexOf(${translate(b)}, ${doCast(byte, Int1Type(true))})" override def bytesLength(b: Ast.expr): String = s"${translate(b)}.length" override def bytesSubscript(container: Ast.expr, idx: Ast.expr): String = - s"${translate(container)}[${translate(idx)}]" + s"(${translate(container)}[${doCast(idx, CalcIntType)}] & 0xff)" override def bytesFirst(b: Ast.expr): String = - s"${translate(b)}[0]" + bytesSubscript(b, Ast.expr.IntNum(0)) override def bytesLast(b: Ast.expr): String = - s"${translate(b)}[(${translate(b)}).length - 1]" + bytesSubscript(b, Ast.expr.BinOp( + Ast.expr.Attribute( + b, + Ast.identifier("length") + ), + Ast.operator.Sub, + Ast.expr.IntNum(1) + )) override def bytesMin(b: Ast.expr): String = s"${JavaCompiler.kstreamName}.byteArrayMin(${translate(b)})" override def bytesMax(b: Ast.expr): String = s"${JavaCompiler.kstreamName}.byteArrayMax(${translate(b)})" override def strLength(s: expr): String = - s"${translate(s)}.length()" + s"(${translate(s)}).length()" override def strReverse(s: expr): String = s"new StringBuilder(${translate(s)}).reverse().toString()" override def strSubstring(s: expr, from: expr, to: expr): String = - s"${translate(s)}.substring(${translate(from)}, ${translate(to)})" + s"(${translate(s)}).substring(${translate(from)}, ${translate(to)})" + override def strToBytes(s: expr, encoding: expr): String = { + importList.add("java.nio.charset.Charset") + s"(${translate(s)}).getBytes(Charset.forName(${translate(encoding)}))" + } override def arrayFirst(a: expr): String = s"${translate(a)}.get(0)" diff --git a/shared/src/main/scala/io/kaitai/struct/translators/PythonTranslator.scala b/shared/src/main/scala/io/kaitai/struct/translators/PythonTranslator.scala index 2a4e21f81..0991ab0d2 100644 --- a/shared/src/main/scala/io/kaitai/struct/translators/PythonTranslator.scala +++ b/shared/src/main/scala/io/kaitai/struct/translators/PythonTranslator.scala @@ -16,6 +16,18 @@ class PythonTranslator(provider: TypeProvider, importList: ImportList) extends B } } + override def doNumericCompareOp(left: Ast.expr, op: Ast.cmpop, right: Ast.expr): String = + s"(${super.doNumericCompareOp(left, op, right)})" + + override def doStrCompareOp(left: Ast.expr, op: Ast.cmpop, right: Ast.expr): String = + s"(${super.doStrCompareOp(left, op, right)})" + + override def doEnumCompareOp(left: Ast.expr, op: Ast.cmpop, right: Ast.expr): String = + s"(${super.doEnumCompareOp(left, op, right)})" + + override def doBytesCompareOp(left: Ast.expr, op: Ast.cmpop, right: Ast.expr): String = + s"(${super.doBytesCompareOp(left, op, right)})" + override def doStringLiteral(s: String): String = "u" + super.doStringLiteral(s) override def doBoolLiteral(n: Boolean): String = if (n) "True" else "False" @@ -104,6 +116,8 @@ class PythonTranslator(provider: TypeProvider, importList: ImportList) extends B } override def bytesToStr(bytesExpr: String, encoding: String): String = s"""($bytesExpr).decode("$encoding")""" + override def bytesIndexOf(b: Ast.expr, byte: Ast.expr): String = + s"${PythonCompiler.kstreamName}.byte_array_index_of(${translate(b)}, ${translate(byte)})" override def bytesLength(value: Ast.expr): String = s"len(${translate(value)})" @@ -125,6 +139,8 @@ class PythonTranslator(provider: TypeProvider, importList: ImportList) extends B s"(${translate(value)})[::-1]" override def strSubstring(s: Ast.expr, from: Ast.expr, to: Ast.expr): String = s"(${translate(s)})[${translate(from)}:${translate(to)}]" + override def strToBytes(s: Ast.expr, encoding: Ast.expr): String = + s"(${translate(s)}).encode(${translate(encoding)})" override def arrayFirst(a: Ast.expr): String = s"${translate(a)}[0]" diff --git a/shared/src/main/scala/io/kaitai/struct/translators/TypeDetector.scala b/shared/src/main/scala/io/kaitai/struct/translators/TypeDetector.scala index 5b017772c..03452bcaa 100644 --- a/shared/src/main/scala/io/kaitai/struct/translators/TypeDetector.scala +++ b/shared/src/main/scala/io/kaitai/struct/translators/TypeDetector.scala @@ -249,7 +249,9 @@ class TypeDetector(provider: TypeProvider) { (objType, methodName.name) match { case (_: StrType, "substring") => CalcStrType case (_: StrType, "to_i") => CalcIntType + case (_: StrType, "to_b") => CalcBytesType case (_: BytesType, "to_s") => CalcStrType + case (_: BytesType, "index_of") => CalcIntType case _ => throw new MethodNotFoundError(methodName.name, objType) }