diff --git a/evmjit/libevmjit/Arith256.cpp b/evmjit/libevmjit/Arith256.cpp index 220aa6f05..47701462c 100644 --- a/evmjit/libevmjit/Arith256.cpp +++ b/evmjit/libevmjit/Arith256.cpp @@ -408,18 +408,49 @@ llvm::Function* Arith256::getMulModFunc() llvm::Value* Arith256::mul(llvm::Value* _arg1, llvm::Value* _arg2) { + if (auto c1 = llvm::dyn_cast(_arg1)) + { + if (auto c2 = llvm::dyn_cast(_arg2)) + return Constant::get(c1->getValue() * c2->getValue()); + } + return createCall(getMulFunc(), {_arg1, _arg2}); } std::pair Arith256::div(llvm::Value* _arg1, llvm::Value* _arg2) { - auto div = m_builder.CreateExtractValue(createCall(getDivFunc(Type::Word), {_arg1, _arg2}), 0, "div"); - auto mod = m_builder.CreateExtractValue(createCall(getDivFunc(Type::Word), {_arg1, _arg2}), 1, "mod"); + if (auto c1 = llvm::dyn_cast(_arg1)) + { + if (auto c2 = llvm::dyn_cast(_arg2)) + { + if (!c2->getValue()) + return std::make_pair(Constant::get(0), Constant::get(0)); + auto div = Constant::get(c1->getValue().udiv(c2->getValue())); + auto mod = Constant::get(c1->getValue().urem(c2->getValue())); + return std::make_pair(div, mod); + } + } + + auto r = createCall(getDivFunc(Type::Word), {_arg1, _arg2}); + auto div = m_builder.CreateExtractValue(r, 0, "div"); + auto mod = m_builder.CreateExtractValue(r, 1, "mod"); return std::make_pair(div, mod); } std::pair Arith256::sdiv(llvm::Value* _x, llvm::Value* _y) { + if (auto c1 = llvm::dyn_cast(_x)) + { + if (auto c2 = llvm::dyn_cast(_y)) + { + if (!c2->getValue()) + return std::make_pair(Constant::get(0), Constant::get(0)); + auto div = Constant::get(c1->getValue().sdiv(c2->getValue())); + auto mod = Constant::get(c1->getValue().srem(c2->getValue())); + return std::make_pair(div, mod); + } + } + auto xIsNeg = m_builder.CreateICmpSLT(_x, Constant::get(0)); auto xNeg = m_builder.CreateSub(Constant::get(0), _x); auto xAbs = m_builder.CreateSelect(xIsNeg, xNeg, _x); @@ -443,16 +474,71 @@ std::pair Arith256::sdiv(llvm::Value* _x, llvm::Valu llvm::Value* Arith256::exp(llvm::Value* _arg1, llvm::Value* _arg2) { + // while (e != 0) { + // if (e % 2 == 1) + // r *= b; + // b *= b; + // e /= 2; + // } + + if (auto c1 = llvm::dyn_cast(_arg1)) + { + if (auto c2 = llvm::dyn_cast(_arg2)) + { + auto b = c1->getValue(); + auto e = c2->getValue(); + auto r = llvm::APInt{256, 1}; + while (e != 0) + { + if (e[0]) + r *= b; + b *= b; + e = e.lshr(1); + } + return Constant::get(r); + } + } + return createCall(getExpFunc(), {_arg1, _arg2}); } llvm::Value* Arith256::addmod(llvm::Value* _arg1, llvm::Value* _arg2, llvm::Value* _arg3) { + if (auto c1 = llvm::dyn_cast(_arg1)) + { + if (auto c2 = llvm::dyn_cast(_arg2)) + { + if (auto c3 = llvm::dyn_cast(_arg3)) + { + if (!c3->getValue()) + return Constant::get(0); + auto s = c1->getValue().zext(256+64) + c2->getValue().zext(256+64); + auto r = s.urem(c3->getValue().zext(256+64)).trunc(256); + return Constant::get(r); + } + } + } + return createCall(getAddModFunc(), {_arg1, _arg2, _arg3}); } llvm::Value* Arith256::mulmod(llvm::Value* _arg1, llvm::Value* _arg2, llvm::Value* _arg3) { + if (auto c1 = llvm::dyn_cast(_arg1)) + { + if (auto c2 = llvm::dyn_cast(_arg2)) + { + if (auto c3 = llvm::dyn_cast(_arg3)) + { + if (!c3->getValue()) + return Constant::get(0); + auto p = c1->getValue().zext(512) * c2->getValue().zext(512); + auto r = p.urem(c3->getValue().zext(512)).trunc(256); + return Constant::get(r); + } + } + } + return createCall(getMulModFunc(), {_arg1, _arg2, _arg3}); }