diff --git a/README.md b/README.md index dcb9ead..ead75db 100644 --- a/README.md +++ b/README.md @@ -3,7 +3,6 @@ **bort** is a cross-platform Small-C language cross compiler for RISC-V architecture. ## Global TODOs -- goto - switch - compound variable declaration (`int a = 5, b, c = d;`) - testing diff --git a/dev.sh b/dev.sh index 7617f3d..5ee5eaa 100755 --- a/dev.sh +++ b/dev.sh @@ -47,5 +47,5 @@ cp -f build/compile_commands.json . if [ $# -ne 0 ] && [ "$1" == "run" ]; then echo -e "----------------------------------\n" set -eux - ./build/bort --dump-ast --emit-ir --dump-codegen-info -o - ./tests/corpus/globals.c + ./build/bort --dump-ast --emit-ir --dump-codegen-info -o - ./tests/corpus/goto.c fi diff --git a/include/bort/AST/ASTNode.hpp b/include/bort/AST/ASTNode.hpp index 4950a8a..805a89f 100644 --- a/include/bort/AST/ASTNode.hpp +++ b/include/bort/AST/ASTNode.hpp @@ -34,6 +34,8 @@ enum class NodeKind { ReturnStmt, BreakStmt, ContinueStmt, + LabelStmt, + GotoStmt, ASTRoot, NUM_NODES }; diff --git a/include/bort/AST/GotoStmt.hpp b/include/bort/AST/GotoStmt.hpp new file mode 100644 index 0000000..ae21285 --- /dev/null +++ b/include/bort/AST/GotoStmt.hpp @@ -0,0 +1,24 @@ +#pragma once +#include "bort/AST/ASTNode.hpp" + +namespace bort::ast { + +class GotoStmt final : public Statement { +private: + explicit GotoStmt(std::string targetLabelName) + : Statement{ NodeKind::GotoStmt }, + m_TargetLabel{ std::move(targetLabelName) } { + } + +public: + [[nodiscard]] auto getTargetLabel() const -> std::string { + return m_TargetLabel; + } + + friend class ASTRoot; + +private: + std::string m_TargetLabel; +}; + +} // namespace bort::ast diff --git a/include/bort/AST/LabelStmt.hpp b/include/bort/AST/LabelStmt.hpp new file mode 100644 index 0000000..ce173e7 --- /dev/null +++ b/include/bort/AST/LabelStmt.hpp @@ -0,0 +1,24 @@ +#pragma once +#include "bort/AST/ASTNode.hpp" + +namespace bort::ast { + +class LabelStmt final : public Statement { +private: + explicit LabelStmt(std::string labelName) + : Statement{ NodeKind::LabelStmt }, + m_LabelName{ std::move(labelName) } { + } + +public: + [[nodiscard]] auto getLabelName() const -> std::string { + return m_LabelName; + } + + friend class ASTRoot; + +private: + std::string m_LabelName; +}; + +} // namespace bort::ast diff --git a/include/bort/AST/Visitors/ASTPrinter.hpp b/include/bort/AST/Visitors/ASTPrinter.hpp index d35936b..a15eee3 100644 --- a/include/bort/AST/Visitors/ASTPrinter.hpp +++ b/include/bort/AST/Visitors/ASTPrinter.hpp @@ -1,8 +1,10 @@ #pragma once #include "bort/AST/BreakStmt.hpp" #include "bort/AST/ContinueStmt.hpp" +#include "bort/AST/GotoStmt.hpp" #include "bort/AST/IndexationExpr.hpp" #include "bort/AST/InitializerList.hpp" +#include "bort/AST/LabelStmt.hpp" #include "bort/AST/UnaryOpExpr.hpp" #include "bort/AST/Visitors/ASTVisitor.hpp" #include @@ -23,6 +25,7 @@ class ASTPrinter : public StructureAwareASTVisitor { void visit(const Ref& varNode) override; void visit(const Ref& strNode) override; void visit(const Ref& charNode) override; + void visit(const Ref& functionCallExpr) override; void visit(const Ref& varDeclNode) override; void visit(const Ref& initializerListNode) override; void visit(const Ref& indexationExpr) override; @@ -36,7 +39,8 @@ class ASTPrinter : public StructureAwareASTVisitor { void visit(const Ref& returnStmtNode) override; void visit(const Ref& breakStmtNode) override; void visit(const Ref& continueStmtNode) override; - void visit(const Ref& functionCallExpr) override; + void visit(const Ref& labelStmtNode) override; + void visit(const Ref& gotoStmtNode) override; void push(); void pop(); diff --git a/include/bort/AST/Visitors/ASTVisitor.hpp b/include/bort/AST/Visitors/ASTVisitor.hpp index 8c2eb5a..a2b2e19 100644 --- a/include/bort/AST/Visitors/ASTVisitor.hpp +++ b/include/bort/AST/Visitors/ASTVisitor.hpp @@ -25,6 +25,8 @@ class WhileStmt; class ReturnStmt; class BreakStmt; class ContinueStmt; +class LabelStmt; +class GotoStmt; class ASTVisitorBase { public: @@ -88,6 +90,12 @@ class StructureAwareASTVisitor : public ASTVisitorBase { virtual void visit(const Ref& /* continueNode */) { // leaf } + virtual void visit(const Ref& /* labelNode */) { + // leaf + } + virtual void visit(const Ref& /* gotoNode */) { + // leaf + } virtual void visit(const Ref& varDeclNode); virtual void visit(const Ref& initializerListNode); virtual void visit(const Ref& indexationExpr); diff --git a/include/bort/AST/Visitors/SymbolResolutionVisitor.hpp b/include/bort/AST/Visitors/SymbolResolutionVisitor.hpp index 7c363ec..ee211cc 100644 --- a/include/bort/AST/Visitors/SymbolResolutionVisitor.hpp +++ b/include/bort/AST/Visitors/SymbolResolutionVisitor.hpp @@ -1,12 +1,14 @@ #pragma once #include "bort/AST/FunctionCallExpr.hpp" #include "bort/AST/FunctionDecl.hpp" +#include "bort/AST/GotoStmt.hpp" #include "bort/AST/Visitors/ASTVisitor.hpp" #include "bort/Basic/Ref.hpp" #include "bort/Frontend/Symbol.hpp" #include #include #include +#include namespace bort::ast { @@ -37,11 +39,13 @@ class SymbolResolutionVisitor final : public StructureAwareASTVisitor { SymbolResolutionVisitor(); protected: + void visit(const Ref& astRoot) override; void visit(const Ref& varNode) override; void visit(const Ref& varDeclNode) override; void visit(const Ref& blockNode) override; void visit(const Ref& functionDeclNode) override; void visit(const Ref& functionCallExpr) override; + void visit(const Ref& gotoStmt) override; private: void push(); @@ -51,5 +55,6 @@ class SymbolResolutionVisitor final : public StructureAwareASTVisitor { [[nodiscard]] auto resolve(const std::string& name) -> Ref; Ref m_CurrentScope; + std::unordered_set m_DefinedLabels; }; } // namespace bort::ast diff --git a/include/bort/AST/Visitors/Utils.hpp b/include/bort/AST/Visitors/Utils.hpp index d42dd57..bb89978 100644 --- a/include/bort/AST/Visitors/Utils.hpp +++ b/include/bort/AST/Visitors/Utils.hpp @@ -7,9 +7,11 @@ #include "bort/AST/ExpressionStmt.hpp" #include "bort/AST/FunctionCallExpr.hpp" #include "bort/AST/FunctionDecl.hpp" +#include "bort/AST/GotoStmt.hpp" #include "bort/AST/IfStmt.hpp" #include "bort/AST/IndexationExpr.hpp" #include "bort/AST/InitializerList.hpp" +#include "bort/AST/LabelStmt.hpp" #include "bort/AST/NumberExpr.hpp" #include "bort/AST/ReturnStmt.hpp" #include "bort/AST/UnaryOpExpr.hpp" @@ -19,6 +21,12 @@ #include "bort/Basic/Casts.hpp" namespace bort::ast { +template +constexpr static auto castVisit(const Ref& node, F&& visit) { + bort_assert_nomsg(dynCastRef((node))); + return visit(dynCastRef((node))); +} + /// \brief Driver for all AST walkers /// /// It is used in StructureAwareASTVisitor which is aimed on more or less @@ -30,73 +38,43 @@ template auto callHandler(const Ref& node, F&& visit) { switch (node->getKind()) { case NodeKind::NumberExpr: - bort_assert_nomsg(dynCastRef(node)); - return visit(dynCastRef(node)); - break; + return castVisit(node, visit); case NodeKind::VariableExpr: - bort_assert_nomsg(dynCastRef(node)); - return visit(dynCastRef(node)); - break; + return castVisit(node, visit); case NodeKind::BinOpExpr: - bort_assert_nomsg(dynCastRef(node)); - return visit(dynCastRef(node)); - break; + return castVisit(node, visit); case NodeKind::UnaryOpExpr: - bort_assert_nomsg(dynCastRef(node)); - return visit(dynCastRef(node)); - break; + return castVisit(node, visit); case NodeKind::VarDecl: - bort_assert_nomsg(dynCastRef(node)); - return visit(dynCastRef(node)); - break; + return castVisit(node, visit); case NodeKind::InitializerList: - bort_assert_nomsg(dynCastRef(node)); - return visit(dynCastRef(node)); - break; + return castVisit(node, visit); case NodeKind::IndexationExpr: - bort_assert_nomsg(dynCastRef(node)); - return visit(dynCastRef(node)); - break; + return castVisit(node, visit); case NodeKind::FunctionDecl: - bort_assert_nomsg(dynCastRef(node)); - return visit(dynCastRef(node)); - break; + return castVisit(node, visit); case NodeKind::FunctionCallExpr: - bort_assert_nomsg(dynCastRef(node)); - return visit(dynCastRef(node)); - break; + return castVisit(node, visit); case NodeKind::ExpressionStmt: - bort_assert_nomsg(dynCastRef(node)); - return visit(dynCastRef(node)); - break; + return castVisit(node, visit); case NodeKind::Block: - bort_assert_nomsg(dynCastRef(node)); - return visit(dynCastRef(node)); - break; + return castVisit(node, visit); case NodeKind::IfStmt: - bort_assert_nomsg(dynCastRef(node)); - return visit(dynCastRef(node)); - break; + return castVisit(node, visit); case NodeKind::WhileStmt: - bort_assert_nomsg(dynCastRef(node)); - return visit(dynCastRef(node)); - break; + return castVisit(node, visit); case NodeKind::ReturnStmt: - bort_assert_nomsg(dynCastRef(node)); - return visit(dynCastRef(node)); - break; + return castVisit(node, visit); case NodeKind::BreakStmt: - bort_assert_nomsg(dynCastRef(node)); - return visit(dynCastRef(node)); - break; + return castVisit(node, visit); case NodeKind::ContinueStmt: - bort_assert_nomsg(dynCastRef(node)); - return visit(dynCastRef(node)); - break; + return castVisit(node, visit); + case NodeKind::LabelStmt: + return castVisit(node, visit); + case NodeKind::GotoStmt: + return castVisit(node, visit); case NodeKind::ASTRoot: - bort_assert_nomsg(dynCastRef(node)); - return visit(dynCastRef(node)); - break; + return castVisit(node, visit); default: bort_assert(false, "Generic visit not implemented for node"); // unreachable, casting to root just for fun diff --git a/include/bort/IR/IRCodegen.hpp b/include/bort/IR/IRCodegen.hpp index 543ca91..3e7c0bd 100644 --- a/include/bort/IR/IRCodegen.hpp +++ b/include/bort/IR/IRCodegen.hpp @@ -4,9 +4,11 @@ #include "bort/AST/BreakStmt.hpp" #include "bort/AST/ContinueStmt.hpp" #include "bort/AST/FunctionCallExpr.hpp" +#include "bort/AST/GotoStmt.hpp" #include "bort/AST/IfStmt.hpp" #include "bort/AST/IndexationExpr.hpp" #include "bort/AST/InitializerList.hpp" +#include "bort/AST/LabelStmt.hpp" #include "bort/AST/NumberExpr.hpp" #include "bort/AST/ReturnStmt.hpp" #include "bort/AST/UnaryOpExpr.hpp" @@ -22,6 +24,8 @@ namespace bort::ir { +struct GotoUnresolvedLabelMD; + class IRCodegen { public: [[nodiscard]] auto takeInstructions() -> Module&& { @@ -38,6 +42,7 @@ class IRCodegen { auto visit(const Ref& unaryOpExpr) -> ValueRef; auto visit(const Ref& numNode) -> ValueRef; auto visit(const Ref& varNode) -> ValueRef; + auto visit(const Ref& funcCallExpr) -> ValueRef; auto visit(const Ref& varDeclNode) -> ValueRef; auto visit(const Ref& initializerListNode) -> ValueRef; @@ -51,13 +56,16 @@ class IRCodegen { auto visit(const Ref& returnStmt) -> ValueRef; auto visit(const Ref& breakStmt) -> ValueRef; auto visit(const Ref& continueStmt) -> ValueRef; - auto visit(const Ref& funcCallExpr) -> ValueRef; + auto visit(const Ref& labelStmt) -> ValueRef; + auto visit(const Ref& gotoStmt) -> ValueRef; + void processGlobalVarDecl(const Ref& varDeclNode); auto genBranchFromCondition(const Ref& cond, bool negate = false) -> Ref; auto genArrayPtr(const ValueRef& arr) -> std::pair, ValueRef>; + void resolveGotoLabels(); template requires std::is_base_of_v diff --git a/include/bort/IR/Module.hpp b/include/bort/IR/Module.hpp index 184fb97..1ac604c 100644 --- a/include/bort/IR/Module.hpp +++ b/include/bort/IR/Module.hpp @@ -99,6 +99,12 @@ class Module : public Value { return it; } + [[nodiscard]] auto getLastFunctionIt() { + auto it{ m_Functions.end() }; + it--; + return it; + } + [[nodiscard]] auto getLastBBIt() { auto it{ m_Functions.back().end() }; it--; diff --git a/include/bort/Lex/Tokens.def b/include/bort/Lex/Tokens.def index 9188807..801e096 100644 --- a/include/bort/Lex/Tokens.def +++ b/include/bort/Lex/Tokens.def @@ -28,6 +28,7 @@ PUNCT(LBracket, "[") PUNCT(RBracket, "]") PUNCT(Comma, ",") PUNCT(Semicolon, ";") +PUNCT(Colon, ":") PUNCT(Amp, "&") PUNCT(Assign, "=") PUNCT(Equals, "==") @@ -73,6 +74,7 @@ KEYWORD(while) KEYWORD(for) KEYWORD(return) KEYWORD(sizeof) +KEYWORD(goto) // Preprocessor PPTOK(define) diff --git a/include/bort/Parse/Parser.hpp b/include/bort/Parse/Parser.hpp index 696a812..94bac6a 100644 --- a/include/bort/Parse/Parser.hpp +++ b/include/bort/Parse/Parser.hpp @@ -5,8 +5,10 @@ #include "bort/AST/ExpressionNode.hpp" #include "bort/AST/FunctionCallExpr.hpp" #include "bort/AST/FunctionDecl.hpp" +#include "bort/AST/GotoStmt.hpp" #include "bort/AST/IfStmt.hpp" #include "bort/AST/InitializerList.hpp" +#include "bort/AST/LabelStmt.hpp" #include "bort/AST/NumberExpr.hpp" #include "bort/AST/ReturnStmt.hpp" #include "bort/AST/VarDecl.hpp" @@ -103,7 +105,8 @@ class Parser { /// -> whileStatement \n /// -> returnStatement \n /// -> breakStatement \n - /// -> continueStatement + /// -> continueStatement \n + /// -> labelStatement auto parseStatement() -> Ref; /// breakStatement -> 'break' ';' auto parseBreakStatement() -> Ref; @@ -120,6 +123,10 @@ class Parser { auto parseForStatement() -> Ref; /// returnStatement -> 'return' (expr)? ';' auto parseReturnStatement() -> Ref; + /// labelStatement -> identifier ':' ';' + auto parseLabelStatement() -> Ref; + /// gotoStatement -> 'goto' identifier ';' + auto parseGotoStatement() -> Ref; private: void disableDiagnostics() { @@ -137,10 +144,7 @@ class Parser { return *m_CurTokIter; } - auto invalidNode() -> std::nullptr_t { - m_ASTInvalid = true; - return nullptr; - } + auto invalidNode() -> std::nullptr_t; inline void consumeToken() { m_CurTokIter++; diff --git a/launch.json b/launch.json index aa2ed9f..5af5810 100644 --- a/launch.json +++ b/launch.json @@ -10,7 +10,7 @@ "--emit-ir", "-o", "-", - "${workspaceFolder}/tests/corpus/loops.c" + "${workspaceFolder}/tests/corpus/goto.c" ], "stopAtEntry": false, "cwd": "${workspaceFolder}", diff --git a/src/AST/Visitors/ASTPrinter.cpp b/src/AST/Visitors/ASTPrinter.cpp index 32ce250..99593a9 100644 --- a/src/AST/Visitors/ASTPrinter.cpp +++ b/src/AST/Visitors/ASTPrinter.cpp @@ -38,6 +38,8 @@ static constexpr cul::BiMap s_NodeKindNames{ [](auto&& selector) { .Case(NodeKind::ReturnStmt, "ReturnStmt") .Case(NodeKind::BreakStmt, "BreakStmt") .Case(NodeKind::ContinueStmt, "ContinueStmt") + .Case(NodeKind::LabelStmt, "LabelStmt") + .Case(NodeKind::GotoStmt, "GotoStmt") .Case(NodeKind::Block, "Block") .Case(NodeKind::ASTRoot, "ASTRoot"); } }; @@ -202,4 +204,13 @@ void ASTPrinter::visit(const Ref& continueStmtNode) { dumpNodeInfo(continueStmtNode); } +void ASTPrinter::visit(const Ref& labelStmtNode) { + dumpNodeInfo(labelStmtNode); + dump("Name", labelStmtNode->getLabelName()); +} + +void ASTPrinter::visit(const Ref& gotoStmtNode) { + dumpNodeInfo(gotoStmtNode); + dump("Target", gotoStmtNode->getTargetLabel()); +} } // namespace bort::ast diff --git a/src/AST/Visitors/SymbolResolutionVisitor.cpp b/src/AST/Visitors/SymbolResolutionVisitor.cpp index 6d635d1..554a669 100644 --- a/src/AST/Visitors/SymbolResolutionVisitor.cpp +++ b/src/AST/Visitors/SymbolResolutionVisitor.cpp @@ -1,13 +1,32 @@ #include "bort/AST/Visitors/SymbolResolutionVisitor.hpp" #include "bort/AST/Block.hpp" +#include "bort/AST/LabelStmt.hpp" #include "bort/AST/VarDecl.hpp" #include "bort/AST/VariableExpr.hpp" #include "bort/AST/Visitors/ASTVisitor.hpp" #include "bort/Basic/Casts.hpp" #include "bort/CLI/IO.hpp" +#include "bort/Frontend/FrontEndInstance.hpp" +#include +#include namespace bort::ast { +struct CollectLabelDefinitionsVisitor : StructureAwareASTVisitor { + void visit(const Ref& labelStmtNode) override { + if (DefinedLabels.contains(labelStmtNode->getLabelName())) { + Diagnostic::emitError( + getASTRoot()->getNodeDebugInfo(labelStmtNode).token, + "Label '{}' already defined", labelStmtNode->getLabelName()); + markASTInvalid(); + return; + } + DefinedLabels.insert(labelStmtNode->getLabelName()); + } + + std::unordered_set DefinedLabels; +}; + Scope::Scope(Ref enclosingScope) : m_EnclosingScope{ std::move(enclosingScope) }, m_Name{ std::nullopt } { @@ -49,6 +68,14 @@ SymbolResolutionVisitor::SymbolResolutionVisitor() : m_CurrentScope{ makeRef(nullptr, "Global") } { } +void SymbolResolutionVisitor::visit(const Ref& astRoot) { + CollectLabelDefinitionsVisitor visitor; + visitor.SAVisit(astRoot); + m_DefinedLabels = std::move(visitor.DefinedLabels); + + StructureAwareASTVisitor::visit(astRoot); +} + void SymbolResolutionVisitor::visit(const Ref& varNode) { if (varNode->isResolved()) { return; @@ -125,6 +152,15 @@ void SymbolResolutionVisitor::visit( pop(); } +void SymbolResolutionVisitor::visit(const Ref& gotoStmt) { + if (!m_DefinedLabels.contains(gotoStmt->getTargetLabel())) { + Diagnostic::emitError(getASTRoot()->getNodeDebugInfo(gotoStmt).token, + "Undefined label: {}", + gotoStmt->getTargetLabel()); + markASTInvalid(); + } +} + void SymbolResolutionVisitor::visit( const Ref& functionCallExpr) { if (functionCallExpr->isResolved()) { diff --git a/src/Codegen/RISCVCodegen.cpp b/src/Codegen/RISCVCodegen.cpp index e1a8443..3a78182 100644 --- a/src/Codegen/RISCVCodegen.cpp +++ b/src/Codegen/RISCVCodegen.cpp @@ -246,8 +246,8 @@ class InstructionChoicePass : public InstructionVisitorBase { .Case(TokenKind::Amp, "and") .Case(TokenKind::Pipe, "or") .Case(TokenKind::Xor, "xor") - .Case(TokenKind::Xor, "sll") - .Case(TokenKind::Xor, "sra"); + .Case(TokenKind::LShift, "sll") + .Case(TokenKind::RShift, "sra"); } }; bort_assert(s_OpInstNames.Find(opInst->getOp()).has_value(), diff --git a/src/IR/IRCodegen.cpp b/src/IR/IRCodegen.cpp index d04be3c..e754d17 100644 --- a/src/IR/IRCodegen.cpp +++ b/src/IR/IRCodegen.cpp @@ -1,4 +1,5 @@ #include "bort/IR/IRCodegen.hpp" +#include "bort/AST/ASTDebugInfo.hpp" #include "bort/AST/ASTNode.hpp" #include "bort/AST/BinOpExpr.hpp" #include "bort/AST/ExpressionNode.hpp" @@ -17,6 +18,7 @@ #include "bort/IR/GlobalValue.hpp" #include "bort/IR/LoadInst.hpp" #include "bort/IR/Metadata.hpp" +#include "bort/IR/Module.hpp" #include "bort/IR/MoveInst.hpp" #include "bort/IR/OpInst.hpp" #include "bort/IR/Register.hpp" @@ -64,6 +66,23 @@ struct BrToLoopStartMDTag : public MDTag { } }; +struct GotoUnresolvedLabelMD : Metadata { + GotoUnresolvedLabelMD(IRFuncIter funcIter, std::string label, + ast::ASTDebugInfo stmtDebugInfo) + : FuncIter{ funcIter }, + Label{ std::move(label) }, + StmtDebugInfo{ std::move(stmtDebugInfo) } { + } + + [[nodiscard]] auto toString() const -> std::string override { + return fmt::format("goto_unresolved_label"); + } + + IRFuncIter FuncIter; + std::string Label; + ast::ASTDebugInfo StmtDebugInfo; +}; + auto IRCodegen::genBranchFromCondition( const Ref& cond, bool negate) -> Ref { @@ -123,9 +142,45 @@ auto IRCodegen::genArrayPtr(const ValueRef& arr) return { ptrTy, arrPtr }; } +static void resolveGotoBranch(const Module& M, + const Ref& br, + GotoUnresolvedLabelMD* GUL) { + for (auto funcIt{ M.begin() }; funcIt != M.end(); ++funcIt) { + for (auto&& bbTarget : *funcIt) { + std::cerr << bbTarget.getName() << std::endl; + if (bbTarget.getName() == GUL->Label) { + if (funcIt != GUL->FuncIter) { + Diagnostic::emitWarning(GUL->StmtDebugInfo.token, + "Goto outside function"); + } + br->setTarget(&bbTarget); + return; + } + } + } + bort_assert(false, "Label resolution failed"); +} + +void IRCodegen::resolveGotoLabels() { + for (auto&& func : m_Module) { + for (auto&& bb : func) { + for (auto&& inst : bb) { + if (auto* GUL{ inst->getMDNode() }) { + auto br{ dynCastRef(inst) }; + bort_assert(br, "Branch expected"); + resolveGotoBranch(m_Module, br, GUL); + br->removeMDNode(); + } + } + } + } +} + void IRCodegen::codegen(const Ref& ast) { m_ASTRoot = ast; genericVisit(ast); + + resolveGotoLabels(); m_Module.revalidateBasicBlocks(); } @@ -453,4 +508,17 @@ auto IRCodegen::visit(const Ref& /*continueStmt*/) return nullptr; } +auto IRCodegen::visit(const Ref& labelStmt) -> ValueRef { + pushBB("", labelStmt->getLabelName()); + return nullptr; +} + +auto IRCodegen::visit(const Ref& gotoStmt) -> ValueRef { + auto newInst{ addInstruction(makeRef()) }; + newInst->addMDNode(GotoUnresolvedLabelMD{ + m_Module.getLastFunctionIt(), gotoStmt->getTargetLabel(), + m_ASTRoot->getNodeDebugInfo(gotoStmt) }); + return nullptr; +} + } // namespace bort::ir diff --git a/src/IR/Module.cpp b/src/IR/Module.cpp index c965f20..cb6fa24 100644 --- a/src/IR/Module.cpp +++ b/src/IR/Module.cpp @@ -1,5 +1,6 @@ #include "bort/IR/Module.hpp" #include "bort/Basic/Assert.hpp" +#include "bort/IR/BranchInst.hpp" #include "bort/IR/GlobalValue.hpp" using namespace bort::ir; @@ -45,12 +46,27 @@ auto bort::ir::Module::addInstruction(Ref instruction) return lastBB.getInstructions().back(); } +static auto hasReferencingBranch(const IRFunction& func, + BasicBlock* bb) -> bool { + for (auto&& BB : func) { + for (auto&& inst : BB) { + if (auto br{ dynCastRef(inst) }) { + if (br->getTarget() == bb) { + return true; + } + } + } + } + return false; +} + void bort::ir::Module::revalidateBasicBlocks() { for (auto&& func : m_Functions) { for (auto it{ func.begin() }; it != func.end();) { /// last BB can be empty if (it->getInstructions().empty() && - ++decltype(it){ it } != func.end()) { + ++decltype(it){ it } != func.end() && + !hasReferencingBranch(func, &*it)) { it = func.erase(it); } else { it++; @@ -59,7 +75,7 @@ void bort::ir::Module::revalidateBasicBlocks() { } } -auto bort::ir::Module::addGlobal(const Ref& global) +auto bort::ir::Module::addGlobal(const Ref& global) -> Ref { m_Globals[global->getName()] = global; return m_Globals.at(global->getName()); @@ -76,7 +92,7 @@ auto bort::ir::Module::getGlobalVariable(const std::string& name) auto bort::ir::Module::getGlobalVariable(const Ref& variable) -> Ref { - return getGlobalVariable(variable->getName()); + return getGlobalVariable(variable->getName()); } void bort::ir::Module::addBasicBlock(std::string name) { @@ -86,4 +102,3 @@ void bort::ir::Module::addBasicBlock(std::string name) { void bort::ir::Module::addFunction(Ref function) { m_Functions.emplace_back(std::move(function)); } - diff --git a/src/Parse/Parser.cpp b/src/Parse/Parser.cpp index c767e9e..8083688 100644 --- a/src/Parse/Parser.cpp +++ b/src/Parse/Parser.cpp @@ -7,6 +7,7 @@ #include "bort/AST/ExpressionNode.hpp" #include "bort/AST/ExpressionStmt.hpp" #include "bort/AST/FunctionCallExpr.hpp" +#include "bort/AST/GotoStmt.hpp" #include "bort/AST/IndexationExpr.hpp" #include "bort/AST/InitializerList.hpp" #include "bort/AST/NumberExpr.hpp" @@ -19,6 +20,7 @@ #include "bort/Basic/Casts.hpp" #include "bort/Basic/Ref.hpp" #include "bort/CLI/IO.hpp" +#include "bort/Frontend/FrontEndInstance.hpp" #include "bort/Frontend/Symbol.hpp" #include "bort/Frontend/Type.hpp" #include "bort/Lex/Token.hpp" @@ -562,6 +564,13 @@ auto Parser::parseStatement() -> Ref { return parseBreakStatement(); case TokenKind::KW_continue: return parseContinueStatement(); + case TokenKind::Identifier: + if (lookahead(1).is(TokenKind::Colon)) { + return parseLabelStatement(); + } + break; + case TokenKind::KW_goto: + return parseGotoStatement(); default: break; } @@ -649,6 +658,12 @@ auto Parser::lookahead(uint32_t offset) const -> const Token& { return *iter; } +auto Parser::invalidNode() -> std::nullptr_t { + m_ASTInvalid = true; + throw FrontEndFatalError{ "Invalid syntax" }; + return nullptr; +} + auto Parser::parseIfStatement() -> Ref { bort_assert(curTok().is(TokenKind::KW_if), "Expected 'if'"); auto ifTok{ curTok() }; @@ -793,4 +808,46 @@ auto Parser::parseForStatement() -> Ref { return outerBlock; } +auto Parser::parseLabelStatement() -> Ref { + bort_assert(curTok().is(TokenKind::Identifier), "Expected identifier"); + auto labelTok{ curTok() }; + consumeToken(); + if (curTok().isNot(TokenKind::Colon)) { + Diagnostic::emitError(curTok(), "Expected ':'"); + return invalidNode(); + } + consumeToken(); + + if (curTok().is(TokenKind::Semicolon)) { + consumeToken(); + } + + return m_ASTRoot->registerNode( + ast::ASTDebugInfo{ labelTok }, + std::string{ labelTok.getStringView() }); +} + +auto Parser::parseGotoStatement() -> Ref { + bort_assert(curTok().is(TokenKind::KW_goto), "Expected 'goto'"); + auto gotoTok{ curTok() }; + consumeToken(); + + if (curTok().isNot(TokenKind::Identifier)) { + Diagnostic::emitError(curTok(), "Expected identifier"); + return invalidNode(); + } + auto labelTok{ curTok() }; + consumeToken(); + + if (curTok().isNot(TokenKind::Semicolon)) { + Diagnostic::emitError(curTok(), "Expected ';'"); + return invalidNode(); + } + consumeToken(); + + return m_ASTRoot->registerNode( + ast::ASTDebugInfo{ gotoTok }, + std::string{ labelTok.getStringView() }); +} + } // namespace bort diff --git a/tests/corpus/goto.c b/tests/corpus/goto.c new file mode 100644 index 0000000..3f82ba8 --- /dev/null +++ b/tests/corpus/goto.c @@ -0,0 +1,16 @@ +int main() { + int x = 3; + int y; + y = x >> 2; + + if (y < x) { + goto my_label; + } + goto my_label2; + +my_label:; + y = x; + +my_label2:; + x = y; +}