From 463025e15de5181a9661e6e67876588908bb8eb9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Bylica?= Date: Tue, 28 Apr 2015 17:50:01 +0200 Subject: [PATCH] Implementation of MUL workaround in "LLVM pass" way. --- libevmjit/Arith256.cpp | 125 ++++++++++++++++------------------ libevmjit/Arith256.h | 5 +- libevmjit/Compiler.cpp | 2 +- libevmjit/ExecutionEngine.cpp | 2 + libevmjit/Optimizer.cpp | 59 ++++++++++++++++ libevmjit/Optimizer.h | 2 + 6 files changed, 124 insertions(+), 71 deletions(-) diff --git a/libevmjit/Arith256.cpp b/libevmjit/Arith256.cpp index c7608c5b8..d8e4c5138 100644 --- a/libevmjit/Arith256.cpp +++ b/libevmjit/Arith256.cpp @@ -4,6 +4,7 @@ #include #include "preprocessor/llvm_includes_start.h" +#include #include #include "preprocessor/llvm_includes_end.h" @@ -32,57 +33,56 @@ void Arith256::debug(llvm::Value* _value, char _c) createCall(m_debug, {m_builder.CreateZExtOrTrunc(_value, Type::Word), m_builder.getInt8(_c)}); } -llvm::Function* Arith256::getMulFunc() +llvm::Function* Arith256::getMulFunc(llvm::Module& _module) { - auto& func = m_mul; - if (!func) - { - llvm::Type* argTypes[] = {Type::Word, Type::Word}; - func = llvm::Function::Create(llvm::FunctionType::get(Type::Word, argTypes, false), llvm::Function::PrivateLinkage, "mul", getModule()); - func->setDoesNotThrow(); - func->setDoesNotAccessMemory(); - - auto x = &func->getArgumentList().front(); - x->setName("x"); - auto y = x->getNextNode(); - y->setName("y"); - - InsertPointGuard guard{m_builder}; - auto bb = llvm::BasicBlock::Create(m_builder.getContext(), {}, func); - m_builder.SetInsertPoint(bb); - auto i64 = Type::Size; - auto i128 = m_builder.getIntNTy(128); - auto i256 = Type::Word; - auto c64 = Constant::get(64); - auto c128 = Constant::get(128); - auto c192 = Constant::get(192); - - auto x_lo = m_builder.CreateTrunc(x, i64, "x.lo"); - auto y_lo = m_builder.CreateTrunc(y, i64, "y.lo"); - auto x_mi = m_builder.CreateTrunc(m_builder.CreateLShr(x, c64), i64); - auto y_mi = m_builder.CreateTrunc(m_builder.CreateLShr(y, c64), i64); - auto x_hi = m_builder.CreateTrunc(m_builder.CreateLShr(x, c128), i128); - auto y_hi = m_builder.CreateTrunc(m_builder.CreateLShr(y, c128), i128); - - auto t1 = m_builder.CreateMul(m_builder.CreateZExt(x_lo, i128), m_builder.CreateZExt(y_lo, i128)); - auto t2 = m_builder.CreateMul(m_builder.CreateZExt(x_lo, i128), m_builder.CreateZExt(y_mi, i128)); - auto t3 = m_builder.CreateMul(m_builder.CreateZExt(x_lo, i128), y_hi); - auto t4 = m_builder.CreateMul(m_builder.CreateZExt(x_mi, i128), m_builder.CreateZExt(y_lo, i128)); - auto t5 = m_builder.CreateMul(m_builder.CreateZExt(x_mi, i128), m_builder.CreateZExt(y_mi, i128)); - auto t6 = m_builder.CreateMul(m_builder.CreateZExt(x_mi, i128), y_hi); - auto t7 = m_builder.CreateMul(x_hi, m_builder.CreateZExt(y_lo, i128)); - auto t8 = m_builder.CreateMul(x_hi, m_builder.CreateZExt(y_mi, i128)); - - auto p = m_builder.CreateZExt(t1, i256); - p = m_builder.CreateAdd(p, m_builder.CreateShl(m_builder.CreateZExt(t2, i256), c64)); - p = m_builder.CreateAdd(p, m_builder.CreateShl(m_builder.CreateZExt(t3, i256), c128)); - p = m_builder.CreateAdd(p, m_builder.CreateShl(m_builder.CreateZExt(t4, i256), c64)); - p = m_builder.CreateAdd(p, m_builder.CreateShl(m_builder.CreateZExt(t5, i256), c128)); - p = m_builder.CreateAdd(p, m_builder.CreateShl(m_builder.CreateZExt(t6, i256), c192)); - p = m_builder.CreateAdd(p, m_builder.CreateShl(m_builder.CreateZExt(t7, i256), c128)); - p = m_builder.CreateAdd(p, m_builder.CreateShl(m_builder.CreateZExt(t8, i256), c192)); - m_builder.CreateRet(p); - } + static const auto funcName = "evm.mul.i256"; + if (auto func = _module.getFunction(funcName)) + return func; + + llvm::Type* argTypes[] = {Type::Word, Type::Word}; + auto func = llvm::Function::Create(llvm::FunctionType::get(Type::Word, argTypes, false), llvm::Function::PrivateLinkage, funcName, &_module); + func->setDoesNotThrow(); + func->setDoesNotAccessMemory(); + + auto x = &func->getArgumentList().front(); + x->setName("x"); + auto y = x->getNextNode(); + y->setName("y"); + + auto bb = llvm::BasicBlock::Create(_module.getContext(), {}, func); + auto builder = llvm::IRBuilder<>{bb}; + auto i64 = Type::Size; + auto i128 = builder.getIntNTy(128); + auto i256 = Type::Word; + auto c64 = Constant::get(64); + auto c128 = Constant::get(128); + auto c192 = Constant::get(192); + + auto x_lo = builder.CreateTrunc(x, i64, "x.lo"); + auto y_lo = builder.CreateTrunc(y, i64, "y.lo"); + auto x_mi = builder.CreateTrunc(builder.CreateLShr(x, c64), i64); + auto y_mi = builder.CreateTrunc(builder.CreateLShr(y, c64), i64); + auto x_hi = builder.CreateTrunc(builder.CreateLShr(x, c128), i128); + auto y_hi = builder.CreateTrunc(builder.CreateLShr(y, c128), i128); + + auto t1 = builder.CreateMul(builder.CreateZExt(x_lo, i128), builder.CreateZExt(y_lo, i128)); + auto t2 = builder.CreateMul(builder.CreateZExt(x_lo, i128), builder.CreateZExt(y_mi, i128)); + auto t3 = builder.CreateMul(builder.CreateZExt(x_lo, i128), y_hi); + auto t4 = builder.CreateMul(builder.CreateZExt(x_mi, i128), builder.CreateZExt(y_lo, i128)); + auto t5 = builder.CreateMul(builder.CreateZExt(x_mi, i128), builder.CreateZExt(y_mi, i128)); + auto t6 = builder.CreateMul(builder.CreateZExt(x_mi, i128), y_hi); + auto t7 = builder.CreateMul(x_hi, builder.CreateZExt(y_lo, i128)); + auto t8 = builder.CreateMul(x_hi, builder.CreateZExt(y_mi, i128)); + + auto p = builder.CreateZExt(t1, i256); + p = builder.CreateAdd(p, builder.CreateShl(builder.CreateZExt(t2, i256), c64)); + p = builder.CreateAdd(p, builder.CreateShl(builder.CreateZExt(t3, i256), c128)); + p = builder.CreateAdd(p, builder.CreateShl(builder.CreateZExt(t4, i256), c64)); + p = builder.CreateAdd(p, builder.CreateShl(builder.CreateZExt(t5, i256), c128)); + p = builder.CreateAdd(p, builder.CreateShl(builder.CreateZExt(t6, i256), c192)); + p = builder.CreateAdd(p, builder.CreateShl(builder.CreateZExt(t7, i256), c128)); + p = builder.CreateAdd(p, builder.CreateShl(builder.CreateZExt(t8, i256), c192)); + builder.CreateRet(p); return func; } @@ -112,10 +112,11 @@ llvm::Function* Arith256::getMul512Func() auto x_hi = m_builder.CreateZExt(m_builder.CreateTrunc(m_builder.CreateLShr(x, Constant::get(128)), i128, "x.hi"), i256); auto y_hi = m_builder.CreateZExt(m_builder.CreateTrunc(m_builder.CreateLShr(y, Constant::get(128)), i128, "y.hi"), i256); - auto t1 = createCall(getMulFunc(), {x_lo, y_lo}); - auto t2 = createCall(getMulFunc(), {x_lo, y_hi}); - auto t3 = createCall(getMulFunc(), {x_hi, y_lo}); - auto t4 = createCall(getMulFunc(), {x_hi, y_hi}); + auto mul256Func = getMulFunc(*getModule()); + auto t1 = createCall(mul256Func, {x_lo, y_lo}); + auto t2 = createCall(mul256Func, {x_lo, y_hi}); + auto t3 = createCall(mul256Func, {x_hi, y_lo}); + auto t4 = createCall(mul256Func, {x_hi, y_hi}); auto p = m_builder.CreateZExt(t1, i512); p = m_builder.CreateAdd(p, m_builder.CreateShl(m_builder.CreateZExt(t2, i512), m_builder.getIntN(512, 128))); @@ -260,14 +261,15 @@ llvm::Function* Arith256::getExpFunc() m_builder.CreateCondBr(eOdd, updateBB, continueBB); m_builder.SetInsertPoint(updateBB); - auto r0 = createCall(getMulFunc(), {r, b}); + auto mul256Func = getMulFunc(*getModule()); + auto r0 = createCall(mul256Func, {r, b}); m_builder.CreateBr(continueBB); m_builder.SetInsertPoint(continueBB); auto r1 = m_builder.CreatePHI(Type::Word, 2, "r1"); r1->addIncoming(r, bodyBB); r1->addIncoming(r0, updateBB); - auto b1 = createCall(getMulFunc(), {b, b}); + auto b1 = createCall(mul256Func, {b, b}); auto e1 = m_builder.CreateLShr(e, Constant::get(1), "e1"); m_builder.CreateBr(headerBB); @@ -347,17 +349,6 @@ llvm::Function* Arith256::getMulModFunc() return m_mulmod; } -llvm::Value* Arith256::mul(llvm::Value* _arg1, llvm::Value* _arg2) -{ - if (auto c1 = llvm::dyn_cast(_arg1)) - { - if (auto c2 = llvm::dyn_cast(_arg2)) - return Constant::get(c1->getValue() * c2->getValue()); - } - - return createCall(getMulFunc(), {_arg1, _arg2}); -} - std::pair Arith256::div(llvm::Value* _arg1, llvm::Value* _arg2) { if (auto c1 = llvm::dyn_cast(_arg1)) diff --git a/libevmjit/Arith256.h b/libevmjit/Arith256.h index 2513ca568..a5762429d 100644 --- a/libevmjit/Arith256.h +++ b/libevmjit/Arith256.h @@ -14,7 +14,6 @@ class Arith256 : public CompilerHelper public: Arith256(llvm::IRBuilder<>& _builder); - llvm::Value* mul(llvm::Value* _arg1, llvm::Value* _arg2); std::pair div(llvm::Value* _arg1, llvm::Value* _arg2); std::pair sdiv(llvm::Value* _arg1, llvm::Value* _arg2); llvm::Value* exp(llvm::Value* _arg1, llvm::Value* _arg2); @@ -23,15 +22,15 @@ public: void debug(llvm::Value* _value, char _c); + static llvm::Function* getMulFunc(llvm::Module& _module); + private: - llvm::Function* getMulFunc(); llvm::Function* getMul512Func(); llvm::Function* getDivFunc(llvm::Type* _type); llvm::Function* getExpFunc(); llvm::Function* getAddModFunc(); llvm::Function* getMulModFunc(); - llvm::Function* m_mul = nullptr; llvm::Function* m_mul512 = nullptr; llvm::Function* m_div = nullptr; llvm::Function* m_div512 = nullptr; diff --git a/libevmjit/Compiler.cpp b/libevmjit/Compiler.cpp index d7240a415..a2f58c770 100644 --- a/libevmjit/Compiler.cpp +++ b/libevmjit/Compiler.cpp @@ -270,7 +270,7 @@ void Compiler::compileBasicBlock(BasicBlock& _basicBlock, RuntimeManager& _runti { auto lhs = stack.pop(); auto rhs = stack.pop(); - auto res = _arith.mul(lhs, rhs); + auto res = m_builder.CreateMul(lhs, rhs); stack.push(res); break; } diff --git a/libevmjit/ExecutionEngine.cpp b/libevmjit/ExecutionEngine.cpp index fba9cfd96..b7cf41010 100644 --- a/libevmjit/ExecutionEngine.cpp +++ b/libevmjit/ExecutionEngine.cpp @@ -162,6 +162,8 @@ ReturnCode ExecutionEngine::run(RuntimeData* _data, Env* _env) listener->stateChanged(ExecState::Optimization); optimize(*module); } + + prepare(*module); } if (g_dump) module->dump(); diff --git a/libevmjit/Optimizer.cpp b/libevmjit/Optimizer.cpp index df88b4df8..4913dcaa0 100644 --- a/libevmjit/Optimizer.cpp +++ b/libevmjit/Optimizer.cpp @@ -1,11 +1,16 @@ #include "Optimizer.h" #include "preprocessor/llvm_includes_start.h" +#include +#include #include #include #include #include "preprocessor/llvm_includes_end.h" +#include "Arith256.h" +#include "Type.h" + namespace dev { namespace eth @@ -24,6 +29,60 @@ bool optimize(llvm::Module& _module) return pm.run(_module); } +namespace +{ + +class LowerEVMPass : public llvm::BasicBlockPass +{ + static char ID; + + bool m_mulFuncNeeded = false; + +public: + LowerEVMPass(): + llvm::BasicBlockPass(ID) + {} + + virtual bool runOnBasicBlock(llvm::BasicBlock& _bb) override; + + virtual bool doFinalization(llvm::Module& _module) override; +}; + +char LowerEVMPass::ID = 0; + +bool LowerEVMPass::runOnBasicBlock(llvm::BasicBlock& _bb) +{ + auto modified = false; + auto module = _bb.getParent()->getParent(); + for (auto&& inst : _bb) + { + if (inst.getOpcode() == llvm::Instruction::Mul) + { + if (inst.getType() == Type::Word) + { + auto call = llvm::CallInst::Create(Arith256::getMulFunc(*module), {inst.getOperand(0), inst.getOperand(1)}, "", &inst); + inst.replaceAllUsesWith(call); + modified = true; + } + } + } + return modified; +} + +bool LowerEVMPass::doFinalization(llvm::Module&) +{ + return false; +} + +} + +bool prepare(llvm::Module& _module) +{ + auto pm = llvm::legacy::PassManager{}; + pm.add(new LowerEVMPass{}); + return pm.run(_module); +} + } } } diff --git a/libevmjit/Optimizer.h b/libevmjit/Optimizer.h index 4a3147a7f..4b7ab7e9a 100644 --- a/libevmjit/Optimizer.h +++ b/libevmjit/Optimizer.h @@ -14,6 +14,8 @@ namespace jit bool optimize(llvm::Module& _module); +bool prepare(llvm::Module& _module); + } } }