diff --git a/libevmjit/Arith256.cpp b/libevmjit/Arith256.cpp index 941a54fbf..862c8452e 100644 --- a/libevmjit/Arith256.cpp +++ b/libevmjit/Arith256.cpp @@ -29,10 +29,8 @@ Arith256::Arith256(llvm::IRBuilder<>& _builder) : using Linkage = GlobalValue::LinkageTypes; llvm::Type* arg2Types[] = {Type::WordPtr, Type::WordPtr, Type::WordPtr}; - llvm::Type* arg3Types[] = {Type::WordPtr, Type::WordPtr, Type::WordPtr, Type::WordPtr}; m_mul = Function::Create(FunctionType::get(Type::Void, arg2Types, false), Linkage::ExternalLinkage, "arith_mul", getModule()); - m_addmod = Function::Create(FunctionType::get(Type::Void, arg3Types, false), Linkage::ExternalLinkage, "arith_addmod", getModule()); } void Arith256::debug(llvm::Value* _value, char _c) @@ -210,6 +208,36 @@ llvm::Function* Arith256::getExpFunc() return m_exp; } +llvm::Function* Arith256::getAddModFunc() +{ + if (!m_addmod) + { + auto i512Ty = m_builder.getIntNTy(512); + llvm::Type* argTypes[] = {Type::Word, Type::Word, Type::Word}; + m_addmod = llvm::Function::Create(llvm::FunctionType::get(Type::Word, argTypes, false), llvm::Function::PrivateLinkage, "addmod", getModule()); + + auto x = &m_addmod->getArgumentList().front(); + x->setName("x"); + auto y = x->getNextNode(); + y->setName("y"); + auto mod = y->getNextNode(); + mod->setName("m"); + + InsertPointGuard guard{m_builder}; + + auto entryBB = llvm::BasicBlock::Create(m_builder.getContext(), {}, m_addmod); + m_builder.SetInsertPoint(entryBB); + auto x512 = m_builder.CreateZExt(x, i512Ty, "x512"); + auto y512 = m_builder.CreateZExt(y, i512Ty, "y512"); + auto m512 = m_builder.CreateZExt(mod, i512Ty, "m512"); + auto s = m_builder.CreateAdd(x512, y512, "s"); + auto d = createCall(getDivFunc(i512Ty), {s, m512}); + auto r = m_builder.CreateExtractValue(d, 1, "r"); + m_builder.CreateRet(r); + } + return m_addmod; +} + llvm::Function* Arith256::getMulModFunc() { if (!m_mulmod) @@ -247,7 +275,6 @@ llvm::Function* Arith256::getMulModFunc() return m_mulmod; } - llvm::Value* Arith256::binaryOp(llvm::Function* _op, llvm::Value* _arg1, llvm::Value* _arg2) { m_builder.CreateStore(_arg1, m_arg1); @@ -256,15 +283,6 @@ llvm::Value* Arith256::binaryOp(llvm::Function* _op, llvm::Value* _arg1, llvm::V return m_builder.CreateLoad(m_result); } -llvm::Value* Arith256::ternaryOp(llvm::Function* _op, llvm::Value* _arg1, llvm::Value* _arg2, llvm::Value* _arg3) -{ - m_builder.CreateStore(_arg1, m_arg1); - m_builder.CreateStore(_arg2, m_arg2); - m_builder.CreateStore(_arg3, m_arg3); - m_builder.CreateCall4(_op, m_arg1, m_arg2, m_arg3, m_result); - return m_builder.CreateLoad(m_result); -} - llvm::Value* Arith256::mul(llvm::Value* _arg1, llvm::Value* _arg2) { return binaryOp(m_mul, _arg1, _arg2); @@ -307,7 +325,7 @@ llvm::Value* Arith256::exp(llvm::Value* _arg1, llvm::Value* _arg2) llvm::Value* Arith256::addmod(llvm::Value* _arg1, llvm::Value* _arg2, llvm::Value* _arg3) { - return ternaryOp(m_addmod, _arg1, _arg2, _arg3); + return createCall(getAddModFunc(), {_arg1, _arg2, _arg3}); } llvm::Value* Arith256::mulmod(llvm::Value* _arg1, llvm::Value* _arg2, llvm::Value* _arg3) @@ -402,46 +420,14 @@ namespace return {lo, (uint128)mid, hi}; } - - inline void mul(i256* x, i256* y) - { - auto a = (uint256*) x; - auto b = (uint256*) y; - *a = mul(*a, *b); - } - - bool isZero(i256 const* _n) - { - return _n->a == 0 && _n->b == 0 && _n->c == 0 && _n->d == 0; - } - - const auto nLimbs = sizeof(i256) / sizeof(mp_limb_t); - - int countLimbs(i256 const* _n) - { - static const auto limbsInWord = sizeof(_n->a) / sizeof(mp_limb_t); - static_assert(limbsInWord == 1, "E?"); - - int l = nLimbs; - if (_n->d != 0) return l; - l -= limbsInWord; - if (_n->c != 0) return l; - l -= limbsInWord; - if (_n->b != 0) return l; - l -= limbsInWord; - if (_n->a != 0) return l; - return 0; - } } } } } - extern "C" { - using namespace dev::eth::jit; EXPORT void debug(uint64_t a, uint64_t b, uint64_t c, uint64_t d, char z) @@ -458,24 +444,4 @@ extern "C" { *o_result = mul512(*_arg1, *_arg2); } - - EXPORT void arith_addmod(i256* _arg1, i256* _arg2, i256* _arg3, i256* o_result) - { - *o_result = {}; - if (isZero(_arg3)) - return; - - mpz_t x{nLimbs, countLimbs(_arg1), reinterpret_cast(_arg1)}; - mpz_t y{nLimbs, countLimbs(_arg2), reinterpret_cast(_arg2)}; - mpz_t m{nLimbs, countLimbs(_arg3), reinterpret_cast(_arg3)}; - mpz_t z{nLimbs, 0, reinterpret_cast(o_result)}; - static mp_limb_t s_limbs[nLimbs + 1] = {}; - static mpz_t s{nLimbs + 1, 0, &s_limbs[0]}; - - mpz_add(s, x, y); - mpz_tdiv_r(z, s, m); - } - } - - diff --git a/libevmjit/Arith256.h b/libevmjit/Arith256.h index ae87d11f2..5852137f8 100644 --- a/libevmjit/Arith256.h +++ b/libevmjit/Arith256.h @@ -26,17 +26,17 @@ public: private: llvm::Function* getDivFunc(llvm::Type* _type); llvm::Function* getExpFunc(); + llvm::Function* getAddModFunc(); llvm::Function* getMulModFunc(); llvm::Value* binaryOp(llvm::Function* _op, llvm::Value* _arg1, llvm::Value* _arg2); - llvm::Value* ternaryOp(llvm::Function* _op, llvm::Value* _arg1, llvm::Value* _arg2, llvm::Value* _arg3); llvm::Function* m_mul; - llvm::Function* m_addmod; llvm::Function* m_div = nullptr; llvm::Function* m_div512 = nullptr; llvm::Function* m_exp = nullptr; + llvm::Function* m_addmod = nullptr; llvm::Function* m_mulmod = nullptr; llvm::Function* m_debug = nullptr;