diff --git a/evmjit/libevmjit/Arith256.cpp b/evmjit/libevmjit/Arith256.cpp index caef06bec..48e3307df 100644 --- a/evmjit/libevmjit/Arith256.cpp +++ b/evmjit/libevmjit/Arith256.cpp @@ -200,6 +200,131 @@ llvm::Function* Arith256::getUDiv256Func(llvm::Module& _module) return func; } +llvm::Function* Arith256::getURem256Func(llvm::Module& _module) +{ + static const auto funcName = "evm.urem.i256"; + if (auto func = _module.getFunction(funcName)) + return func; + + auto udivremFunc = getUDivRem256Func(_module); + + auto func = llvm::Function::Create(llvm::FunctionType::get(Type::Word, {Type::Word, Type::Word}, false), llvm::Function::PrivateLinkage, funcName, &_module); + func->setDoesNotThrow(); + func->setDoesNotAccessMemory(); + + auto x = &func->getArgumentList().front(); + x->setName("x"); + auto y = x->getNextNode(); + y->setName("y"); + + auto bb = llvm::BasicBlock::Create(_module.getContext(), {}, func); + auto builder = llvm::IRBuilder<>{bb}; + auto udivrem = builder.CreateCall(udivremFunc, {x, y}); + auto r = builder.CreateExtractElement(udivrem, uint64_t(1)); + builder.CreateRet(r); + + return func; +} + +llvm::Function* Arith256::getSDivRem256Func(llvm::Module& _module) +{ + static const auto funcName = "evm.sdivrem.i256"; + if (auto func = _module.getFunction(funcName)) + return func; + + auto udivremFunc = getUDivRem256Func(_module); + + auto retType = llvm::VectorType::get(Type::Word, 2); + auto func = llvm::Function::Create(llvm::FunctionType::get(retType, {Type::Word, Type::Word}, false), llvm::Function::PrivateLinkage, funcName, &_module); + func->setDoesNotThrow(); + func->setDoesNotAccessMemory(); + + auto x = &func->getArgumentList().front(); + x->setName("x"); + auto y = x->getNextNode(); + y->setName("y"); + + auto bb = llvm::BasicBlock::Create(_module.getContext(), "", func); + auto builder = llvm::IRBuilder<>{bb}; + auto xIsNeg = builder.CreateICmpSLT(x, Constant::get(0)); + auto xNeg = builder.CreateSub(Constant::get(0), x); + auto xAbs = builder.CreateSelect(xIsNeg, xNeg, x); + + auto yIsNeg = builder.CreateICmpSLT(y, Constant::get(0)); + auto yNeg = builder.CreateSub(Constant::get(0), y); + auto yAbs = builder.CreateSelect(yIsNeg, yNeg, y); + + auto res = builder.CreateCall(udivremFunc, {xAbs, yAbs}); + auto qAbs = builder.CreateExtractElement(res, uint64_t(0)); + auto rAbs = builder.CreateExtractElement(res, 1); + + // the reminder has the same sign as dividend + auto rNeg = builder.CreateSub(Constant::get(0), rAbs); + auto r = builder.CreateSelect(xIsNeg, rNeg, rAbs); + + auto qNeg = builder.CreateSub(Constant::get(0), qAbs); + auto xyOpposite = builder.CreateXor(xIsNeg, yIsNeg); + auto q = builder.CreateSelect(xyOpposite, qNeg, qAbs); + + auto ret = builder.CreateInsertElement(llvm::UndefValue::get(retType), q, uint64_t(0)); + ret = builder.CreateInsertElement(ret, r, 1); + builder.CreateRet(ret); + + return func; +} + +llvm::Function* Arith256::getSDiv256Func(llvm::Module& _module) +{ + static const auto funcName = "evm.sdiv.i256"; + if (auto func = _module.getFunction(funcName)) + return func; + + auto sdivremFunc = getSDivRem256Func(_module); + + auto func = llvm::Function::Create(llvm::FunctionType::get(Type::Word, {Type::Word, Type::Word}, false), llvm::Function::PrivateLinkage, funcName, &_module); + func->setDoesNotThrow(); + func->setDoesNotAccessMemory(); + + auto x = &func->getArgumentList().front(); + x->setName("x"); + auto y = x->getNextNode(); + y->setName("y"); + + auto bb = llvm::BasicBlock::Create(_module.getContext(), {}, func); + auto builder = llvm::IRBuilder<>{bb}; + auto sdivrem = builder.CreateCall(sdivremFunc, {x, y}); + auto q = builder.CreateExtractElement(sdivrem, uint64_t(0)); + builder.CreateRet(q); + + return func; +} + +llvm::Function* Arith256::getSRem256Func(llvm::Module& _module) +{ + static const auto funcName = "evm.srem.i256"; + if (auto func = _module.getFunction(funcName)) + return func; + + auto sdivremFunc = getSDivRem256Func(_module); + + auto func = llvm::Function::Create(llvm::FunctionType::get(Type::Word, {Type::Word, Type::Word}, false), llvm::Function::PrivateLinkage, funcName, &_module); + func->setDoesNotThrow(); + func->setDoesNotAccessMemory(); + + auto x = &func->getArgumentList().front(); + x->setName("x"); + auto y = x->getNextNode(); + y->setName("y"); + + auto bb = llvm::BasicBlock::Create(_module.getContext(), {}, func); + auto builder = llvm::IRBuilder<>{bb}; + auto sdivrem = builder.CreateCall(sdivremFunc, {x, y}); + auto r = builder.CreateExtractElement(sdivrem, uint64_t(1)); + builder.CreateRet(r); + + return func; +} + llvm::Function* Arith256::getMul512Func() { auto& func = m_mul512; @@ -462,61 +587,6 @@ llvm::Function* Arith256::getMulModFunc() return m_mulmod; } -std::pair Arith256::div(llvm::Value* _arg1, llvm::Value* _arg2) -{ - if (auto c1 = llvm::dyn_cast(_arg1)) - { - if (auto c2 = llvm::dyn_cast(_arg2)) - { - if (!c2->getValue()) - return std::make_pair(Constant::get(0), Constant::get(0)); - auto div = Constant::get(c1->getValue().udiv(c2->getValue())); - auto mod = Constant::get(c1->getValue().urem(c2->getValue())); - return std::make_pair(div, mod); - } - } - - auto r = createCall(getDivFunc(Type::Word), {_arg1, _arg2}); - auto div = m_builder.CreateExtractElement(r, uint64_t(0), "div"); - auto mod = m_builder.CreateExtractElement(r, 1, "mod"); - return std::make_pair(div, mod); -} - -std::pair Arith256::sdiv(llvm::Value* _x, llvm::Value* _y) -{ - if (auto c1 = llvm::dyn_cast(_x)) - { - if (auto c2 = llvm::dyn_cast(_y)) - { - if (!c2->getValue()) - return std::make_pair(Constant::get(0), Constant::get(0)); - auto div = Constant::get(c1->getValue().sdiv(c2->getValue())); - auto mod = Constant::get(c1->getValue().srem(c2->getValue())); - return std::make_pair(div, mod); - } - } - - auto xIsNeg = m_builder.CreateICmpSLT(_x, Constant::get(0)); - auto xNeg = m_builder.CreateSub(Constant::get(0), _x); - auto xAbs = m_builder.CreateSelect(xIsNeg, xNeg, _x); - - auto yIsNeg = m_builder.CreateICmpSLT(_y, Constant::get(0)); - auto yNeg = m_builder.CreateSub(Constant::get(0), _y); - auto yAbs = m_builder.CreateSelect(yIsNeg, yNeg, _y); - - auto res = div(xAbs, yAbs); - - // the reminder has the same sign as dividend - auto rNeg = m_builder.CreateSub(Constant::get(0), res.second); - res.second = m_builder.CreateSelect(xIsNeg, rNeg, res.second); - - auto qNeg = m_builder.CreateSub(Constant::get(0), res.first); - auto xyOpposite = m_builder.CreateXor(xIsNeg, yIsNeg); - res.first = m_builder.CreateSelect(xyOpposite, qNeg, res.first); - - return res; -} - llvm::Value* Arith256::exp(llvm::Value* _arg1, llvm::Value* _arg2) { // while (e != 0) { diff --git a/evmjit/libevmjit/Arith256.h b/evmjit/libevmjit/Arith256.h index 3ee016073..aeea830db 100644 --- a/evmjit/libevmjit/Arith256.h +++ b/evmjit/libevmjit/Arith256.h @@ -14,8 +14,6 @@ class Arith256 : public CompilerHelper public: Arith256(llvm::IRBuilder<>& _builder); - std::pair div(llvm::Value* _arg1, llvm::Value* _arg2); - std::pair sdiv(llvm::Value* _arg1, llvm::Value* _arg2); llvm::Value* exp(llvm::Value* _arg1, llvm::Value* _arg2); llvm::Value* mulmod(llvm::Value* _arg1, llvm::Value* _arg2, llvm::Value* _arg3); llvm::Value* addmod(llvm::Value* _arg1, llvm::Value* _arg2, llvm::Value* _arg3); @@ -24,7 +22,11 @@ public: static llvm::Function* getMulFunc(llvm::Module& _module); static llvm::Function* getUDiv256Func(llvm::Module& _module); + static llvm::Function* getURem256Func(llvm::Module& _module); static llvm::Function* getUDivRem256Func(llvm::Module& _module); + static llvm::Function* getSDiv256Func(llvm::Module& _module); + static llvm::Function* getSRem256Func(llvm::Module& _module); + static llvm::Function* getSDivRem256Func(llvm::Module& _module); private: llvm::Function* getMul512Func(); diff --git a/evmjit/libevmjit/Compiler.cpp b/evmjit/libevmjit/Compiler.cpp index 2e4976c78..7ae3e067b 100644 --- a/evmjit/libevmjit/Compiler.cpp +++ b/evmjit/libevmjit/Compiler.cpp @@ -288,8 +288,8 @@ void Compiler::compileBasicBlock(BasicBlock& _basicBlock, RuntimeManager& _runti { auto lhs = stack.pop(); auto rhs = stack.pop(); - auto res = _arith.sdiv(lhs, rhs); - stack.push(res.first); + auto res = m_builder.CreateSDiv(lhs, rhs); + stack.push(res); break; } @@ -297,8 +297,8 @@ void Compiler::compileBasicBlock(BasicBlock& _basicBlock, RuntimeManager& _runti { auto lhs = stack.pop(); auto rhs = stack.pop(); - auto res = _arith.div(lhs, rhs); - stack.push(res.second); + auto res = m_builder.CreateURem(lhs, rhs); + stack.push(res); break; } @@ -306,8 +306,8 @@ void Compiler::compileBasicBlock(BasicBlock& _basicBlock, RuntimeManager& _runti { auto lhs = stack.pop(); auto rhs = stack.pop(); - auto res = _arith.sdiv(lhs, rhs); - stack.push(res.second); + auto res = m_builder.CreateSRem(lhs, rhs); + stack.push(res); break; } diff --git a/evmjit/libevmjit/Optimizer.cpp b/evmjit/libevmjit/Optimizer.cpp index 982f1cefe..8eaab9a99 100644 --- a/evmjit/libevmjit/Optimizer.cpp +++ b/evmjit/libevmjit/Optimizer.cpp @@ -69,6 +69,18 @@ bool LowerEVMPass::runOnBasicBlock(llvm::BasicBlock& _bb) case llvm::Instruction::UDiv: func = Arith256::getUDiv256Func(*module); break; + + case llvm::Instruction::URem: + func = Arith256::getURem256Func(*module); + break; + + case llvm::Instruction::SDiv: + func = Arith256::getSDiv256Func(*module); + break; + + case llvm::Instruction::SRem: + func = Arith256::getSRem256Func(*module); + break; } }