From 5ab23aa1a3294e56a480f0b0dbad1dda5af50d6f Mon Sep 17 00:00:00 2001 From: lidotcircle Date: Tue, 1 Aug 2023 08:30:50 +0000 Subject: [PATCH] scalc llvm codegen --- .github/workflows/build.yml | 48 +++ example/SimpleCalculator/CMakeLists.txt | 8 + example/SimpleCalculator/include/scalc/ast.h | 87 ++++- .../include/scalc/lexer_parser.h | 2 + .../include/scalc/llvm_visitor.h | 51 +++ .../SimpleCalculator/include/scalc/parser.h | 1 + example/SimpleCalculator/lib/ast.cpp | 9 +- example/SimpleCalculator/lib/context.cpp | 1 - example/SimpleCalculator/lib/lexer_parser.cpp | 10 + example/SimpleCalculator/lib/llvm_visitor.cpp | 328 ++++++++++++++++++ example/SimpleCalculator/lib/parser.cpp | 38 ++ example/SimpleCalculator/lib/tokenizer.cpp | 3 +- example/SimpleCalculator/src/main.cpp | 32 +- 13 files changed, 605 insertions(+), 13 deletions(-) create mode 100644 .github/workflows/build.yml create mode 100644 example/SimpleCalculator/include/scalc/llvm_visitor.h create mode 100644 example/SimpleCalculator/lib/llvm_visitor.cpp diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml new file mode 100644 index 0000000..ce52f92 --- /dev/null +++ b/.github/workflows/build.yml @@ -0,0 +1,48 @@ +name: Test +on: + push: + branches: + - master + tags: + - v[0-9]+.[0-9] + +jobs: + BuildAndTestUbuntu: + runs-on: ubuntu-20.04 + steps: + - uses: actions/checkout@v2 + with: + submodules: recursive + + - name: release name + id: release + run: | + branch=$(echo ${{github.ref}} | sed 's/refs\/\(heads\|tags\)\///g') && \ + release=$([[ "$branch" =~ v[0-9.]+ ]] && echo $branch || echo $branch) && \ + echo "release_name=$release" && \ + printf "::set-output name=release::%s\n" "$release" + + - name: install clang + run: | + wget https://apt.llvm.org/llvm.sh && \ + chmod +x llvm.sh && \ + sudo ./llvm.sh 17 && \ + sudo apt install clang-tidy && \ + echo "clang-17: $(which clang-17), clang-tidy: $(which clang-tidy-17)" + + - name: Release Build + run: | + export CC=$(which clang-17) && export CXX=$(which clang++-17) && \ + mkdir build && cd build && cmake .. -DCMAKE_BUILD_TYPE=Release -DCXX_VERSION=17 && make + + - name: Test + run: | + cd build && ctest + + - name: release + uses: ncipollo/release-action@v1 + with: + allowUpdates: true + artifacts: "./build/example/SimpleCalculator/scalc" + token: ${{ secrets.GITHUB_TOKEN }} + tag: ${{ steps.release.outputs.release }} diff --git a/example/SimpleCalculator/CMakeLists.txt b/example/SimpleCalculator/CMakeLists.txt index 4aa5b84..915373c 100644 --- a/example/SimpleCalculator/CMakeLists.txt +++ b/example/SimpleCalculator/CMakeLists.txt @@ -1,8 +1,16 @@ +find_package(LLVM REQUIRED CONFIG) +message(STATUS "Found LLVM ${LLVM_PACKAGE_VERSION}") +message(STATUS "Using LLVMConfig.cmake in: ${LLVM_DIR}") + +include_directories(${LLVM_INCLUDE_DIRS}) +add_definitions(${LLVM_DEFINITIONS}) +llvm_map_components_to_libnames(llvm_libs support core irreader) file(GLOB_RECURSE scalclib_SOURCES CONFIGURE_DEPENDS ./lib/**.cpp) add_library(scalclib STATIC ${scalclib_SOURCES}) set_property(TARGET scalclib PROPERTY CXX_STANDARD 17) target_link_libraries(scalclib PUBLIC dcparse) +target_link_libraries(scalclib PUBLIC ${llvm_libs}) target_include_directories(scalclib PUBLIC "${CMAKE_CURRENT_LIST_DIR}/include") add_executable(scalc ./src/main.cpp) diff --git a/example/SimpleCalculator/include/scalc/ast.h b/example/SimpleCalculator/include/scalc/ast.h index 071374d..9e1b576 100644 --- a/example/SimpleCalculator/include/scalc/ast.h +++ b/example/SimpleCalculator/include/scalc/ast.h @@ -2,13 +2,42 @@ #define _SIMPLE_CALCULATOR_AST_H_ #include "parser/parser.h" -#include "./token.h" #include #include +#include using ASTNodeParserContext = std::weak_ptr; +class ASTNodeExpr; +class ASTNodeExprList; +class ASTNodeBlockStat; +class ASTNodeStat; +class ASTNodeStatList; +class ASTNodeIFStat; +class ASTNodeFunctionDef; +class ASTNodeFunctionDecl; +class ASTNodeCalcUnit; +using UnitItem = std::variant< + std::shared_ptr, + std::shared_ptr, + std::shared_ptr>; + + +class ASTNodeVisitor { +public: + virtual void visitExpr(const ASTNodeExpr&) = 0; + virtual void visitExprStat(const ASTNodeExprList&) = 0; + virtual void visitBlockStat(const ASTNodeStatList&) = 0; + virtual void visitStatList(const std::vector>&) = 0; + virtual void visitForStat(const ASTNodeExprList*, const ASTNodeExpr*, const ASTNodeExprList*, const ASTNodeStat&) = 0; + virtual void visitIfStat(const ASTNodeExpr&, const ASTNodeStat&, const ASTNodeStat*) = 0; + virtual void visitReturnStat(const ASTNodeExpr&) = 0; + virtual void visitFuncDef(const std::string&, const std::vector&, const ASTNodeBlockStat&) = 0; + virtual void visitFuncDecl(const std::string&, const std::vector&) = 0; + virtual void visitCalcUnit(const std::vector& items) = 0; +}; + class ASTNode { private: ASTNodeParserContext m_parser_context; @@ -17,6 +46,7 @@ class ASTNode { inline ASTNode(ASTNodeParserContext p): m_parser_context(p) {} inline ASTNodeParserContext context() { return this->m_parser_context; } virtual ~ASTNode() = default; + virtual void accept(ASTNodeVisitor& visitor) const = 0; }; @@ -26,6 +56,7 @@ class ASTNodeExpr : public ASTNode { virtual double evaluate() = 0; virtual bool used() { return false; } + void accept(ASTNodeVisitor& visitor) const override { visitor.visitExpr(*this); } }; class UnaryOperatorExpr: public ASTNodeExpr { @@ -48,6 +79,9 @@ class UnaryOperatorExpr: public ASTNodeExpr { ASTNodeExpr(c), m_operator(optype), m_expr(expr) {} + UnaryOperatorType op() const { return m_operator; } + const ASTNodeExpr& cc() const { return *m_expr; } + virtual double evaluate() override; }; @@ -76,6 +110,10 @@ class BinaryOperatorExpr: public ASTNodeExpr { ): ASTNodeExpr(c), m_operator(optype), m_left(std::move(left)), m_right(std::move(right)) {} + + BinaryOperatorType op() const { return m_operator; } + const ASTNodeExpr& left() const { return *m_left; } + const ASTNodeExpr& right() const { return *m_right; } virtual double evaluate() override; bool used() override { return m_operator == ASSIGNMENT; } @@ -129,6 +167,7 @@ class ASTNodeExprList: public ASTNode, private std::vector StatementList() const { return this->_statlist; } virtual void execute() override; + void accept(ASTNodeVisitor& visitor) const override { visitor.visitBlockStat(*_statlist); } }; class ASTNodeExprStat: public ASTNodeStat @@ -213,6 +255,7 @@ class ASTNodeExprStat: public ASTNodeStat inline const std::shared_ptr exprList() const { return this->_exprlist; } virtual void execute() override; + void accept(ASTNodeVisitor& visitor) const override { visitor.visitExprStat(*_exprlist); } }; class ASTNodeReturnStat: public ASTNodeStat @@ -229,6 +272,7 @@ class ASTNodeReturnStat: public ASTNodeStat inline const std::shared_ptr expr() const { return this->_expr; } virtual void execute() override; + void accept(ASTNodeVisitor& visitor) const override { visitor.visitReturnStat(*_expr); } }; class ASTNodeIFStat: public ASTNodeStat @@ -251,6 +295,7 @@ class ASTNodeIFStat: public ASTNodeStat inline std::shared_ptr condition() { return this->_cond; } virtual void execute() override; + void accept(ASTNodeVisitor& visitor) const override { visitor.visitIfStat(*_cond, *_truestat, _falsestat.get()); } }; class ASTNodeFORStat: public ASTNodeStat @@ -276,6 +321,7 @@ class ASTNodeFORStat: public ASTNodeStat inline std::shared_ptr post() { return this->_post; } virtual void execute() override; + void accept(ASTNodeVisitor& visitor) const override { visitor.visitForStat(_pre.get(), _cond.get(), _post.get(), *_stat); } }; /* @@ -312,6 +358,8 @@ class ASTNodeArgList: public ASTNode, private std::vector using container_t::size; using container_t::operator[]; using container_t::push_back; + + virtual void accept(ASTNodeVisitor& visitor) const { assert(false); } }; class ASTNodeFunctionDef: public ASTNode @@ -334,19 +382,52 @@ class ASTNodeFunctionDef: public ASTNode inline std::shared_ptr blockStat() { return this->blockstat; } void call(std::vector parameters); + void accept(ASTNodeVisitor& visitor) const override { + std::vector args; + for (auto& a: *arglist) args.push_back(a); + visitor.visitFuncDef(static_cast(func.get())->id(), args, *blockstat); + } +}; + +class ASTNodeFunctionDecl: public ASTNode +{ +private: + std::shared_ptr func; + std::shared_ptr arglist; + +public: + inline ASTNodeFunctionDecl( + ASTNodeParserContext c, + std::shared_ptr func, + std::shared_ptr args): + ASTNode(c), func(func), arglist(args) {} + + inline std::shared_ptr function() { return this->func; } + inline std::shared_ptr argList() { return this->arglist; } + + void accept(ASTNodeVisitor& visitor) const override { + std::vector args; + for (auto& a: *arglist) args.push_back(a); + visitor.visitFuncDecl(static_cast(func.get())->id(), args); + } }; class ASTNodeCalcUnit: public ASTNode { private: - std::vector> functions; - std::vector> statements; + std::vector m_unititem; public: inline ASTNodeCalcUnit(ASTNodeParserContext c): ASTNode(c) {} void push_function(std::shared_ptr func) ; + void push_function_decl(std::shared_ptr func) ; void push_statement(std::shared_ptr stat); + + virtual void accept(ASTNodeVisitor& visitor) const + { + visitor.visitCalcUnit(m_unititem); + } }; #endif // _SIMPLE_CALCULATOR_AST_H_ diff --git a/example/SimpleCalculator/include/scalc/lexer_parser.h b/example/SimpleCalculator/include/scalc/lexer_parser.h index f3b59fc..120c166 100644 --- a/example/SimpleCalculator/include/scalc/lexer_parser.h +++ b/example/SimpleCalculator/include/scalc/lexer_parser.h @@ -9,6 +9,7 @@ class CalcLexerParser { private: CalcParser parser; CalcLexer lexer; + std::shared_ptr m_calunit; public: CalcLexerParser(bool execute); @@ -16,6 +17,7 @@ class CalcLexerParser { void feed(char c); std::shared_ptr end(); void reset(); + std::string genllvm(const std::string& modulename) const; std::shared_ptr getContext(); inline void setDebugStream(std::ostream& out) { this->parser.setDebugStream(out); } diff --git a/example/SimpleCalculator/include/scalc/llvm_visitor.h b/example/SimpleCalculator/include/scalc/llvm_visitor.h new file mode 100644 index 0000000..756bc62 --- /dev/null +++ b/example/SimpleCalculator/include/scalc/llvm_visitor.h @@ -0,0 +1,51 @@ +#include "ast.h" +#include +#include +#include +#include +#include +#include +#include +#include + + +class ASTNodeVisitorLLVMGen: public ASTNodeVisitor { +private: + std::unique_ptr m_context; + std::unique_ptr m_module; + std::unique_ptr> m_builder; + std::vector>> m_namedValues; + llvm::Value* m_prevValue; + bool m_globalscope; + + llvm::Value* lookup_var(const std::string& name) + { + for (size_t i=m_namedValues.size();i>0;i--) { + auto& k = m_namedValues.at(i - 1); + if (k.count(name)) { + if (std::holds_alternative(k.at(name))) { + return std::get<0>(k.at(name)); + } else { + return std::get<1>(k.at(name)); + } + } + } + return nullptr; + } + +public: + ASTNodeVisitorLLVMGen(const std::string& filename); + + void visitExpr(const ASTNodeExpr&) override; + void visitExprStat(const ASTNodeExprList&) override; + void visitBlockStat(const ASTNodeStatList&) override; + void visitStatList(const std::vector>&) override; + void visitForStat(const ASTNodeExprList*, const ASTNodeExpr*, const ASTNodeExprList*, const ASTNodeStat&) override; + void visitIfStat(const ASTNodeExpr&, const ASTNodeStat&, const ASTNodeStat*) override; + void visitReturnStat(const ASTNodeExpr&) override; + void visitFuncDef(const std::string&, const std::vector&, const ASTNodeBlockStat&) override; + void visitFuncDecl(const std::string&, const std::vector&) override; + void visitCalcUnit(const std::vector&) override; + + std::string codegen() const; +}; diff --git a/example/SimpleCalculator/include/scalc/parser.h b/example/SimpleCalculator/include/scalc/parser.h index 223e092..d9787a6 100644 --- a/example/SimpleCalculator/include/scalc/parser.h +++ b/example/SimpleCalculator/include/scalc/parser.h @@ -9,6 +9,7 @@ TENTRY(CalcUnit) \ TENTRY(ArgList) \ TENTRY(FunctionDef) \ + TENTRY(FunctionDecl) \ \ TENTRY(IfStatement) \ TENTRY(ForStatement) \ diff --git a/example/SimpleCalculator/lib/ast.cpp b/example/SimpleCalculator/lib/ast.cpp index 910cd59..74093f2 100644 --- a/example/SimpleCalculator/lib/ast.cpp +++ b/example/SimpleCalculator/lib/ast.cpp @@ -41,7 +41,7 @@ class ASTNodeFunctionExec: public SCalcFunction void ASTNodeCalcUnit::push_function(shared_ptr func) { - this->functions.push_back(func); + this->m_unititem.push_back(func); auto funcname = func->function(); auto funcnamet = dynamic_pointer_cast(funcname); @@ -52,9 +52,14 @@ void ASTNodeCalcUnit::push_function(shared_ptr func) context->add_function(funcnamet->id(), make_shared(func)); } +void ASTNodeCalcUnit::push_function_decl(shared_ptr func) +{ + this->m_unititem.push_back(func); +} + void ASTNodeCalcUnit::push_statement(shared_ptr stat) { - this->statements.push_back(stat); + this->m_unititem.push_back(stat); auto context = this->context(); auto ptr = context.lock(); diff --git a/example/SimpleCalculator/lib/context.cpp b/example/SimpleCalculator/lib/context.cpp index 75431c0..59279d6 100644 --- a/example/SimpleCalculator/lib/context.cpp +++ b/example/SimpleCalculator/lib/context.cpp @@ -1,5 +1,4 @@ #include "scalc/context.h" -#include "scalc/ast.h" #include "scalc/scalc_error.h" #include "scalc/defer.hpp" #include diff --git a/example/SimpleCalculator/lib/lexer_parser.cpp b/example/SimpleCalculator/lib/lexer_parser.cpp index 45dd484..15cae0d 100644 --- a/example/SimpleCalculator/lib/lexer_parser.cpp +++ b/example/SimpleCalculator/lib/lexer_parser.cpp @@ -1,4 +1,5 @@ #include "scalc/lexer_parser.h" +#include "scalc/llvm_visitor.h" using namespace std; @@ -22,6 +23,7 @@ std::shared_ptr CalcLexerParser::end() auto nonterm = this->parser.end(); auto calcunit = dynamic_pointer_cast(nonterm); + m_calunit = calcunit->astnode; return calcunit->astnode; } @@ -31,6 +33,14 @@ void CalcLexerParser::reset() this->parser.reset(); } +std::string CalcLexerParser::genllvm(const std::string& modulename) const +{ + assert(m_calunit); + ASTNodeVisitorLLVMGen visitor(modulename); + m_calunit->accept(visitor); + return visitor.codegen(); +} + std::shared_ptr CalcLexerParser::getContext() { auto ctx = this->parser.getContext(); diff --git a/example/SimpleCalculator/lib/llvm_visitor.cpp b/example/SimpleCalculator/lib/llvm_visitor.cpp new file mode 100644 index 0000000..0221d93 --- /dev/null +++ b/example/SimpleCalculator/lib/llvm_visitor.cpp @@ -0,0 +1,328 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "scalc/llvm_visitor.h" +#include "scalc/scalc_error.h" +#include "llvm/Support/raw_ostream.h" + +ASTNodeVisitorLLVMGen::ASTNodeVisitorLLVMGen(const std::string& filename) : + m_context(std::make_unique()), + m_module(std::make_unique(filename, *m_context)), + m_builder(std::make_unique>(*m_context)), + m_namedValues(1), + m_prevValue(nullptr), + m_globalscope(true) +{ +} + +void ASTNodeVisitorLLVMGen::visitExpr(const ASTNodeExpr& expr) +{ + if (dynamic_cast(&expr)) { + auto id = dynamic_cast(&expr)->id(); + auto v = this->lookup_var(id); + if (!v) { + throw SCalcError(id + " is not declared"); + } + m_prevValue = m_builder->CreateLoad(llvm::Type::getDoubleTy(*m_context), v, id); + } else if (dynamic_cast(&expr)) { + auto& e = dynamic_cast(expr); + e.right().accept(*this); + auto right = m_prevValue; + auto left = right; + if (e.op() != BinaryOperatorExpr::ASSIGNMENT) { + e.left().accept(*this); + left = m_prevValue; + } + switch (e.op()) { + case BinaryOperatorExpr::PLUS: + m_prevValue = m_builder->CreateFAdd(left, right, "fplus_r"); + break; + case BinaryOperatorExpr::MINUS: + m_prevValue = m_builder->CreateFSub(left, right, "fminus_r"); + break; + case BinaryOperatorExpr::MULTIPLY: + m_prevValue = m_builder->CreateFMul(left, right, "fmul_r"); + break; + case BinaryOperatorExpr::DIVISION: + m_prevValue = m_builder->CreateFDiv(left, right, "fdiv_r"); + break; + case BinaryOperatorExpr::REMAINDER: + m_prevValue = m_builder->CreateFRem(left, right, "frem_r"); + break; + case BinaryOperatorExpr::CARET: + { + auto func = m_module->getFunction("pow"); + assert(func); + m_prevValue = m_builder->CreateCall(func, {left, right}, "fpow_r"); + } break; + case BinaryOperatorExpr::ASSIGNMENT: + { + auto id = dynamic_cast(&e.left()); + if (!id) throw SCalcError("expected an id"); + auto v = this->lookup_var(id->id()); + if (!v) { + if (m_globalscope) { + auto gv = new llvm::GlobalVariable(llvm::Type::getDoubleTy(*m_context), false, llvm::GlobalVariable::PrivateLinkage, nullptr, id->id()); + m_module->insertGlobalVariable(gv); + m_namedValues.back().insert({id->id(), gv}); + v = gv; + } else { + auto addr = m_builder->CreateAlloca(llvm::Type::getDoubleTy(*m_context), nullptr, "local"); + m_namedValues.back().insert({id->id(), addr}); + v = addr; + } + } + m_builder->CreateStore(right, v); + m_prevValue = m_builder->CreateLoad(right->getType(), v, "assign_r"); + } break; + case BinaryOperatorExpr::EQUAL: + m_prevValue = m_builder->CreateFCmpOEQ(left, right, "cmp"); + m_prevValue = m_builder->CreateCast(llvm::Instruction::CastOps::UIToFP, m_prevValue, llvm::Type::getDoubleTy(*m_context), "cast2fp_r"); + break; + case BinaryOperatorExpr::NOTEQUAL: + m_prevValue = m_builder->CreateFCmpONE(left, right, "cmp"); + m_prevValue = m_builder->CreateCast(llvm::Instruction::CastOps::UIToFP, m_prevValue, llvm::Type::getDoubleTy(*m_context), "cast2fp_r"); + break; + case BinaryOperatorExpr::GREATERTHAN: + m_prevValue = m_builder->CreateFCmpOGT(left, right, "cmp"); + m_prevValue = m_builder->CreateCast(llvm::Instruction::CastOps::UIToFP, m_prevValue, llvm::Type::getDoubleTy(*m_context), "cast2fp_r"); + break; + case BinaryOperatorExpr::LESSTHAN: + m_prevValue = m_builder->CreateFCmpOLT(left, right, "cmp"); + m_prevValue = m_builder->CreateCast(llvm::Instruction::CastOps::UIToFP, m_prevValue, llvm::Type::getDoubleTy(*m_context), "cast2fp_r"); + break; + case BinaryOperatorExpr::GREATEREQUAL: + m_prevValue = m_builder->CreateFCmpOGE(left, right, "cmp"); + m_prevValue = m_builder->CreateCast(llvm::Instruction::CastOps::UIToFP, m_prevValue, llvm::Type::getDoubleTy(*m_context), "cast2fp_r"); + break; + case BinaryOperatorExpr::LESSEQUAL: + m_prevValue = m_builder->CreateFCmpOLE(left, right, "cmp"); + m_prevValue = m_builder->CreateCast(llvm::Instruction::CastOps::UIToFP, m_prevValue, llvm::Type::getDoubleTy(*m_context), "cast2fp_r"); + break; + } + } else if (dynamic_cast(&expr)) { + auto& e = dynamic_cast(expr); + auto cn = dynamic_cast(&e.cc()); + if (!cn) + throw SCalcError("require an lvalue"); + auto addr = this->lookup_var(cn->id()); + if (!addr) + throw SCalcError("variable '" + cn->id() + "' is not defined"); + auto oldvalue = m_builder->CreateLoad(llvm::Type::getDoubleTy(*m_context), addr, cn->id()); + switch (e.op()) { + case UnaryOperatorExpr::PRE_INC: + { + auto v = m_builder->CreateFAdd(oldvalue, llvm::ConstantFP::get(*m_context, llvm::APFloat(1.0)), "inc1"); + m_builder->CreateStore(v, addr); + m_prevValue = v; + } break; + case UnaryOperatorExpr::PRE_DEC: + { + auto v = m_builder->CreateFAdd(oldvalue, llvm::ConstantFP::get(*m_context, llvm::APFloat(-1.0)), "dec1"); + m_builder->CreateStore(v, addr); + m_prevValue = v; + } break; + case UnaryOperatorExpr::POS_INC: + { + auto v = m_builder->CreateFAdd(oldvalue, llvm::ConstantFP::get(*m_context, llvm::APFloat(1.0)), "inc1x"); + m_builder->CreateStore(v, addr); + m_prevValue = oldvalue; + } break; + case UnaryOperatorExpr::POS_DEC: + { + auto v = m_builder->CreateFAdd(oldvalue, llvm::ConstantFP::get(*m_context, llvm::APFloat(-1.0)), "dec1x"); + m_builder->CreateStore(v, addr); + m_prevValue = oldvalue; + } break; + } + } else if (dynamic_cast(&expr)) { + auto& e = dynamic_cast(expr); + auto id = dynamic_cast(e.function().get()); + if (!id) throw SCalcError("invalid function call"); + const auto funcname = id->id(); + llvm::Function* f = m_module->getFunction(funcname); + if (!f) throw SCalcError("undefined function '" + funcname + "'"); + if (f->arg_size() != e.parameters()->size()) throw SCalcError("incorrect number of parameters"); + std::vector args; + for (auto& a: *e.parameters()) { + a->accept(*this); + args.push_back(m_prevValue); + } + m_prevValue = m_builder->CreateCall(f, args, funcname + "_result"); + } else if (dynamic_cast(&expr)) { + auto& e = dynamic_cast(expr); + m_prevValue = llvm::ConstantFP::get(*m_context, llvm::APFloat(e.value())); + } +} + +void ASTNodeVisitorLLVMGen::visitExprStat(const ASTNodeExprList& exprlist) +{ + m_prevValue = nullptr; + for (auto& e: exprlist) { + e->accept(*this); + m_prevValue = nullptr; + } +} + +void ASTNodeVisitorLLVMGen::visitStatList(const std::vector>& statlist) +{ + for (auto& stat: statlist) stat->accept(*this); +} + +void ASTNodeVisitorLLVMGen::visitForStat(const ASTNodeExprList* pre, const ASTNodeExpr* cond, const ASTNodeExprList* post, const ASTNodeStat& stat) +{ + if (pre) { + for (auto& e: *pre) + e->accept(*this); + } + auto func = m_builder->GetInsertBlock()->getParent(); + llvm::BasicBlock* after_b = llvm::BasicBlock::Create(*m_context, "loop_after", func); + llvm::BasicBlock* loop_b = llvm::BasicBlock::Create(*m_context, "loop", func, after_b); + llvm::Value* cv = nullptr; + if (cond) { + cond->accept(*this); + cv = m_builder->CreateFCmpONE(m_prevValue, llvm::ConstantFP::get(*m_context, llvm::APFloat(0.0)), "forcond"); + } else { + cv = llvm::ConstantInt::getTrue(*m_context); + } + m_builder->CreateCondBr(cv, loop_b, after_b); + m_builder->SetInsertPoint(loop_b); + stat.accept(*this); + if (post) { + for (auto& e: *post) + e->accept(*this); + } + if (cond) { + cond->accept(*this); + cv = m_builder->CreateFCmpONE(m_prevValue, llvm::ConstantFP::get(*m_context, llvm::APFloat(0.0)), "forcond"); + } else { + cv = llvm::ConstantInt::getTrue(*m_context); + } + m_builder->CreateCondBr(cv, loop_b, after_b); + m_builder->SetInsertPoint(after_b); +} + +void ASTNodeVisitorLLVMGen::visitIfStat(const ASTNodeExpr& cond, const ASTNodeStat& then, const ASTNodeStat* else_) +{ + cond.accept(*this); + auto cmp = m_builder->CreateFCmpONE(m_prevValue, llvm::ConstantFP::get(*m_context, llvm::APFloat(0.0)), "ifcond"); + auto func = m_builder->GetInsertBlock()->getParent(); + llvm::BasicBlock* then_b = llvm::BasicBlock::Create(*m_context, "then", func); + llvm::BasicBlock* else_b = llvm::BasicBlock::Create(*m_context, "else", func); + m_builder->CreateCondBr(cmp, then_b, else_b); + m_builder->SetInsertPoint(then_b); + then.accept(*this); + m_builder->SetInsertPoint(else_b); + if (else_) { + else_->accept(*this); + } + if (!then_b->getTerminator() || !else_b->getTerminator()) { + llvm::BasicBlock* ifcontinue_b = llvm::BasicBlock::Create(*m_context, "ifcontinue", func); + if (!then_b->getTerminator()) { + m_builder->SetInsertPoint(then_b); + m_builder->CreateBr(ifcontinue_b); + } + if (!else_b->getTerminator()) { + m_builder->SetInsertPoint(else_b); + m_builder->CreateBr(ifcontinue_b); + } + m_builder->SetInsertPoint(ifcontinue_b); + m_builder->CreateAnd(cmp, cmp, "damp"); + } +} + +void ASTNodeVisitorLLVMGen::visitBlockStat(const ASTNodeStatList& stat) +{ + m_namedValues.emplace_back(); + for (auto& s: stat) { + s->accept(*this); + } + m_namedValues.pop_back(); +} + +void ASTNodeVisitorLLVMGen::visitReturnStat(const ASTNodeExpr& expr) +{ + expr.accept(*this); + m_builder->CreateRet(m_prevValue); +} + +static llvm::AllocaInst *CreateEntryBlockAlloca(llvm::Function *TheFunction, + llvm::StringRef VarName, llvm::LLVMContext& context) +{ + llvm::IRBuilder<> TmpB(&TheFunction->getEntryBlock(), + TheFunction->getEntryBlock().begin()); + return TmpB.CreateAlloca(llvm::Type::getDoubleTy(context), nullptr, VarName); +} + +void ASTNodeVisitorLLVMGen::visitFuncDef(const std::string& funcname, const std::vector& args, const ASTNodeBlockStat& block) +{ + std::vector argstype(args.size(), llvm::Type::getDoubleTy(*m_context)); + llvm::FunctionType* ft = llvm::FunctionType::get(llvm::Type::getDoubleTy(*m_context), argstype, false); + llvm::Function* f = llvm::Function::Create(ft, llvm::Function::LinkageTypes::InternalLinkage, funcname, m_module.get()); + size_t i=0; + m_namedValues.emplace_back(); + llvm::BasicBlock* fblock = llvm::BasicBlock::Create(*m_context, "entry", f); + m_builder->SetInsertPoint(fblock); + for (auto& a: f->args()) { + auto alloc = CreateEntryBlockAlloca(f, args.at(i), *m_context); + m_builder->CreateStore(&a, alloc); + m_namedValues.back()[args.at(i)] = alloc; + i++; + } + block.accept(*this); + llvm::verifyFunction(*f); + m_namedValues.pop_back(); +} + +void ASTNodeVisitorLLVMGen::visitFuncDecl(const std::string& funcname, const std::vector& args) +{ + std::vector argstype(args.size(), llvm::Type::getDoubleTy(*m_context)); + llvm::FunctionType* ft = llvm::FunctionType::get(llvm::Type::getDoubleTy(*m_context), argstype, false); + llvm::Function* f = llvm::Function::Create(ft, llvm::Function::LinkageTypes::ExternalLinkage, funcname, m_module.get()); +} + +void ASTNodeVisitorLLVMGen::visitCalcUnit(const std::vector& items) +{ + std::vector argstype; + llvm::FunctionType* ft = llvm::FunctionType::get(llvm::Type::getInt32Ty(*m_context), argstype, false); + llvm::Function* f = llvm::Function::Create(ft, llvm::Function::LinkageTypes::ExternalLinkage, "main", m_module.get()); + llvm::BasicBlock* gblock = llvm::BasicBlock::Create(*m_context, "main_entry", f); + for (auto& item: items) { + if (std::holds_alternative>(item)) { + auto& funcdef = std::get>(item); + m_globalscope = false; + m_builder->SetInsertPoint(static_cast(nullptr)); + funcdef->accept(*this); + } else if (std::holds_alternative>(item)) { + auto& funcdecl = std::get>(item); + funcdecl->accept(*this); + } else if (std::holds_alternative>(item)) { + auto& stat = std::get>(item); + m_builder->SetInsertPoint(gblock); + m_globalscope = true; + stat->accept(*this); + } else { + assert(false); + } + } + m_builder->SetInsertPoint(gblock); + m_builder->CreateRet(llvm::ConstantInt::get(llvm::Type::getInt32Ty(*m_context), 0)); + llvm::verifyFunction(*f); +} + +std::string ASTNodeVisitorLLVMGen::codegen() const +{ + std::string buf; + llvm::raw_string_ostream ss(buf); + m_module->print(ss, nullptr); + return buf; +} diff --git a/example/SimpleCalculator/lib/parser.cpp b/example/SimpleCalculator/lib/parser.cpp index aa6bd6c..d9a4445 100644 --- a/example/SimpleCalculator/lib/parser.cpp +++ b/example/SimpleCalculator/lib/parser.cpp @@ -348,6 +348,23 @@ void CalcParser::function_rules() return make_shared(ast); }); + parser( NI(FunctionDecl), { TI(FUNCTION), NI(Expr), TI(LPAREN), ParserChar::beOptional(NI(ArgList)), TI(RPAREN), TI(SEMICOLON) }, + [] (auto c, auto ts) { + assert(ts.size() == 6); + pnonterm(Expr, ASTNodeExpr, 1, func); + + std::shared_ptr argsast; + if (ts[3]) { + pnonterm(ArgList, ASTNodeArgList, 3, argsx); + argsast = argsxast; + } else { + argsast = make_shared(c); + } + + auto ast = make_shared(c, funcast, argsast); + return make_shared(ast); + }); + parser( NI(ArgList), { TI(ID) }, [] (auto c, auto ts) { assert(ts.size() == 1); @@ -418,6 +435,27 @@ void CalcParser::calcunit_rules() unitast->push_function(funcast); return make_shared(unitast); }); + + parser( NI(CalcUnit), { NI(FunctionDecl) }, + [] (auto c, auto ts) { + assert(ts.size() == 1); + pnonterm(FunctionDecl, ASTNodeFunctionDecl, 0, func); + + auto ast = make_shared(c); + ast->push_function_decl(funcast); + + return make_shared(ast); + }); + + parser( NI(CalcUnit), { NI(CalcUnit), NI(FunctionDecl) }, + [] (auto c, auto ts) { + assert(ts.size() == 2); + pnonterm(CalcUnit, ASTNodeCalcUnit, 0, unit); + pnonterm(FunctionDecl, ASTNodeFunctionDecl, 1, func); + + unitast->push_function_decl(funcast); + return make_shared(unitast); + }); } CalcParser::CalcParser(bool execute) diff --git a/example/SimpleCalculator/lib/tokenizer.cpp b/example/SimpleCalculator/lib/tokenizer.cpp index 8e3050c..4407753 100644 --- a/example/SimpleCalculator/lib/tokenizer.cpp +++ b/example/SimpleCalculator/lib/tokenizer.cpp @@ -1,6 +1,5 @@ #include "scalc/token.h" #include "scalc/parser.h" -#include "lexer/lexer_rule.hpp" #include "lexer/lexer_rule_regex.hpp" #include using namespace std; @@ -59,7 +58,7 @@ CalcLexer::CalcLexer() lexer( std::make_unique>( - s2u("//[^\n]*"), + s2u("//[^\n]*\n"), [](auto str, auto info) { return nullptr; }, false, true) diff --git a/example/SimpleCalculator/src/main.cpp b/example/SimpleCalculator/src/main.cpp index 5750cc7..dd92549 100644 --- a/example/SimpleCalculator/src/main.cpp +++ b/example/SimpleCalculator/src/main.cpp @@ -1,5 +1,4 @@ #include "scalc/lexer_parser.h" -#include "scalc/scalc_error.h" #include #include #include @@ -32,15 +31,12 @@ static string trimstring(string str) int main(int argc, char* argv[]) { - CalcLexerParser exec(true); - auto ctx = exec.getContext(); - ctx->set_output(&cout); - if (argc > 3) { usage(argv[0]); return 1; } + string modulename = ""; string input; if (argc == 3) { if (string(argv[1]) == "-f") { @@ -53,6 +49,17 @@ int main(int argc, char* argv[]) ss << ifs.rdbuf(); input = ss.str(); if (input.empty()) input = ";"; + } else if (string(argv[1]) == "-c") { + modulename = argv[2]; + ifstream ifs(argv[2]); + if (!ifs) { + cerr << "File not found: " << argv[2] << endl; + return 1; + } + stringstream ss; + ss << ifs.rdbuf(); + input = ss.str(); + if (input.empty()) input = ";"; } else { usage(argv[0]); return 1; @@ -68,12 +75,27 @@ int main(int argc, char* argv[]) input += ';'; } + CalcLexerParser exec(modulename.empty()); + auto ctx = exec.getContext(); + ctx->set_output(&cout); + const std::string preamble = + "function sin(x);" + "function cos(x);"; if (!input.empty()) { + input = preamble + input; try { for (auto c: input) exec.feed(c); exec.end(); + if (!modulename.empty()) { + const auto ir = exec.genllvm(modulename); + const auto pos = modulename.find_last_of('.'); + const auto of = modulename.substr(0, pos) + ".ll"; + std::fstream fs(of, std::ios::binary | std::ios::out); + fs.write(ir.c_str(), ir.size()); + if (fs.bad()) cerr << "fail write to file: " << std::endl << ir << std::endl; + } return 0; } catch (const runtime_error& e) { cerr << "Error: " << e.what() << endl;