diff --git a/libevmjit/Arith256.cpp b/libevmjit/Arith256.cpp index a550b2e3f..941a54fbf 100644 --- a/libevmjit/Arith256.cpp +++ b/libevmjit/Arith256.cpp @@ -7,6 +7,7 @@ #include #include #include +#include namespace dev { @@ -32,7 +33,6 @@ Arith256::Arith256(llvm::IRBuilder<>& _builder) : m_mul = Function::Create(FunctionType::get(Type::Void, arg2Types, false), Linkage::ExternalLinkage, "arith_mul", getModule()); m_addmod = Function::Create(FunctionType::get(Type::Void, arg3Types, false), Linkage::ExternalLinkage, "arith_addmod", getModule()); - m_mulmod = Function::Create(FunctionType::get(Type::Void, arg3Types, false), Linkage::ExternalLinkage, "arith_mulmod", getModule()); } void Arith256::debug(llvm::Value* _value, char _c) @@ -45,64 +45,70 @@ void Arith256::debug(llvm::Value* _value, char _c) createCall(m_debug, {_value, m_builder.getInt8(_c)}); } -llvm::Function* Arith256::getDivFunc() +llvm::Function* Arith256::getDivFunc(llvm::Type* _type) { - if (!m_div) + auto& func = _type == Type::Word ? m_div : m_div512; + + if (!func) { // Based of "Improved shift divisor algorithm" from "Software Integer Division" by Microsoft Research // The following algorithm also handles divisor of value 0 returning 0 for both quotient and reminder - llvm::Type* argTypes[] = {Type::Word, Type::Word}; + llvm::Type* argTypes[] = {_type, _type}; auto retType = llvm::StructType::get(m_builder.getContext(), llvm::ArrayRef{argTypes}); - m_div = llvm::Function::Create(llvm::FunctionType::get(retType, argTypes, false), llvm::Function::PrivateLinkage, "arith.div", getModule()); + auto funcName = _type == Type::Word ? "div" : "div512"; + func = llvm::Function::Create(llvm::FunctionType::get(retType, argTypes, false), llvm::Function::PrivateLinkage, funcName, getModule()); + + auto zero = llvm::ConstantInt::get(_type, 0); + auto one = llvm::ConstantInt::get(_type, 1); - auto x = &m_div->getArgumentList().front(); + auto x = &func->getArgumentList().front(); x->setName("x"); auto yArg = x->getNextNode(); yArg->setName("y"); InsertPointGuard guard{m_builder}; - auto entryBB = llvm::BasicBlock::Create(m_builder.getContext(), "Entry", m_div); - auto mainBB = llvm::BasicBlock::Create(m_builder.getContext(), "Main", m_div); - auto loopBB = llvm::BasicBlock::Create(m_builder.getContext(), "Loop", m_div); - auto continueBB = llvm::BasicBlock::Create(m_builder.getContext(), "Continue", m_div); - auto returnBB = llvm::BasicBlock::Create(m_builder.getContext(), "Return", m_div); + auto entryBB = llvm::BasicBlock::Create(m_builder.getContext(), "Entry", func); + auto mainBB = llvm::BasicBlock::Create(m_builder.getContext(), "Main", func); + auto loopBB = llvm::BasicBlock::Create(m_builder.getContext(), "Loop", func); + auto continueBB = llvm::BasicBlock::Create(m_builder.getContext(), "Continue", func); + auto returnBB = llvm::BasicBlock::Create(m_builder.getContext(), "Return", func); m_builder.SetInsertPoint(entryBB); - auto yNonZero = m_builder.CreateICmpNE(yArg, Constant::get(0)); + auto yNonZero = m_builder.CreateICmpNE(yArg, zero); auto yLEx = m_builder.CreateICmpULE(yArg, x); - auto r0 = m_builder.CreateSelect(yNonZero, x, Constant::get(0), "r0"); + auto r0 = m_builder.CreateSelect(yNonZero, x, zero, "r0"); m_builder.CreateCondBr(m_builder.CreateAnd(yLEx, yNonZero), mainBB, returnBB); m_builder.SetInsertPoint(mainBB); - auto ctlzIntr = llvm::Intrinsic::getDeclaration(getModule(), llvm::Intrinsic::ctlz, Type::Word); + auto ctlzIntr = llvm::Intrinsic::getDeclaration(getModule(), llvm::Intrinsic::ctlz, _type); // both y and r are non-zero auto yLz = m_builder.CreateCall2(ctlzIntr, yArg, m_builder.getInt1(true), "y.lz"); auto rLz = m_builder.CreateCall2(ctlzIntr, r0, m_builder.getInt1(true), "r.lz"); auto i0 = m_builder.CreateNUWSub(yLz, rLz, "i0"); - auto shlBy0 = m_builder.CreateICmpEQ(i0, Constant::get(0)); + auto shlBy0 = m_builder.CreateICmpEQ(i0, zero); auto y0 = m_builder.CreateShl(yArg, i0); y0 = m_builder.CreateSelect(shlBy0, yArg, y0, "y0"); // Workaround for LLVM bug: shl by 0 produces wrong result m_builder.CreateBr(loopBB); m_builder.SetInsertPoint(loopBB); - auto yPhi = m_builder.CreatePHI(Type::Word, 2, "y.phi"); - auto rPhi = m_builder.CreatePHI(Type::Word, 2, "r.phi"); - auto iPhi = m_builder.CreatePHI(Type::Word, 2, "i.phi"); - auto qPhi = m_builder.CreatePHI(Type::Word, 2, "q.phi"); + auto yPhi = m_builder.CreatePHI(_type, 2, "y.phi"); + auto rPhi = m_builder.CreatePHI(_type, 2, "r.phi"); + auto iPhi = m_builder.CreatePHI(_type, 2, "i.phi"); + auto qPhi = m_builder.CreatePHI(_type, 2, "q.phi"); auto rUpdate = m_builder.CreateNUWSub(rPhi, yPhi); - auto qUpdate = m_builder.CreateOr(qPhi, Constant::get(1)); // q += 1, q lowest bit is 0 + auto qUpdate = m_builder.CreateOr(qPhi, one); // q += 1, q lowest bit is 0 auto rGEy = m_builder.CreateICmpUGE(rPhi, yPhi); auto r1 = m_builder.CreateSelect(rGEy, rUpdate, rPhi, "r1"); auto q1 = m_builder.CreateSelect(rGEy, qUpdate, qPhi, "q"); - auto iZero = m_builder.CreateICmpEQ(iPhi, Constant::get(0)); + auto iZero = m_builder.CreateICmpEQ(iPhi, zero); m_builder.CreateCondBr(iZero, returnBB, continueBB); m_builder.SetInsertPoint(continueBB); - auto i2 = m_builder.CreateNUWSub(iPhi, Constant::get(1)); - auto q2 = m_builder.CreateShl(q1, Constant::get(1)); - auto y2 = m_builder.CreateUDiv(yPhi, Constant::get(2)); + auto i2 = m_builder.CreateNUWSub(iPhi, one); + auto q2 = m_builder.CreateShl(q1, one); + auto y2 = m_builder.CreateLShr(yPhi, one); m_builder.CreateBr(loopBB); yPhi->addIncoming(y0, mainBB); @@ -111,21 +117,21 @@ llvm::Function* Arith256::getDivFunc() rPhi->addIncoming(r1, continueBB); iPhi->addIncoming(i0, mainBB); iPhi->addIncoming(i2, continueBB); - qPhi->addIncoming(Constant::get(0), mainBB); + qPhi->addIncoming(zero, mainBB); qPhi->addIncoming(q2, continueBB); m_builder.SetInsertPoint(returnBB); - auto qRet = m_builder.CreatePHI(Type::Word, 2, "q.ret"); - qRet->addIncoming(Constant::get(0), entryBB); + auto qRet = m_builder.CreatePHI(_type, 2, "q.ret"); + qRet->addIncoming(zero, entryBB); qRet->addIncoming(q1, loopBB); - auto rRet = m_builder.CreatePHI(Type::Word, 2, "r.ret"); + auto rRet = m_builder.CreatePHI(_type, 2, "r.ret"); rRet->addIncoming(r0, entryBB); rRet->addIncoming(r1, loopBB); auto ret = m_builder.CreateInsertValue(llvm::UndefValue::get(retType), qRet, 0, "ret0"); ret = m_builder.CreateInsertValue(ret, rRet, 1, "ret"); m_builder.CreateRet(ret); } - return m_div; + return func; } llvm::Function* Arith256::getExpFunc() @@ -204,6 +210,42 @@ llvm::Function* Arith256::getExpFunc() return m_exp; } +llvm::Function* Arith256::getMulModFunc() +{ + if (!m_mulmod) + { + llvm::Type* argTypes[] = {Type::Word, Type::Word, Type::Word}; + m_mulmod = llvm::Function::Create(llvm::FunctionType::get(Type::Word, argTypes, false), llvm::Function::PrivateLinkage, "mulmod", getModule()); + + auto i512Ty = m_builder.getIntNTy(512); + llvm::Type* mul512ArgTypes[] = {Type::WordPtr, Type::WordPtr, i512Ty->getPointerTo()}; + auto mul512 = llvm::Function::Create(llvm::FunctionType::get(Type::Void, mul512ArgTypes, false), llvm::Function::ExternalLinkage, "arith_mul512", getModule()); + + auto x = &m_mulmod->getArgumentList().front(); + x->setName("x"); + auto y = x->getNextNode(); + y->setName("y"); + auto mod = y->getNextNode(); + mod->setName("mod"); + + InsertPointGuard guard{m_builder}; + + auto entryBB = llvm::BasicBlock::Create(m_builder.getContext(), {}, m_mulmod); + m_builder.SetInsertPoint(entryBB); + auto a1 = m_builder.CreateAlloca(Type::Word); + auto a2 = m_builder.CreateAlloca(Type::Word); + auto a3 = m_builder.CreateAlloca(i512Ty); + m_builder.CreateStore(x, a1); + m_builder.CreateStore(y, a2); + createCall(mul512, {a1, a2, a3}); + auto p = m_builder.CreateLoad(a3, "p"); + auto m = m_builder.CreateZExt(mod, i512Ty, "m"); + auto d = createCall(getDivFunc(i512Ty), {p, m}); + auto r = m_builder.CreateExtractValue(d, 1, "r"); + m_builder.CreateRet(r); + } + return m_mulmod; +} llvm::Value* Arith256::binaryOp(llvm::Function* _op, llvm::Value* _arg1, llvm::Value* _arg2) @@ -230,8 +272,8 @@ llvm::Value* Arith256::mul(llvm::Value* _arg1, llvm::Value* _arg2) std::pair Arith256::div(llvm::Value* _arg1, llvm::Value* _arg2) { - auto div = m_builder.CreateExtractValue(createCall(getDivFunc(), {_arg1, _arg2}), 0, "div"); - auto mod = m_builder.CreateExtractValue(createCall(getDivFunc(), {_arg1, _arg2}), 1, "mod"); + auto div = m_builder.CreateExtractValue(createCall(getDivFunc(Type::Word), {_arg1, _arg2}), 0, "div"); + auto mod = m_builder.CreateExtractValue(createCall(getDivFunc(Type::Word), {_arg1, _arg2}), 1, "mod"); return std::make_pair(div, mod); } @@ -270,39 +312,58 @@ llvm::Value* Arith256::addmod(llvm::Value* _arg1, llvm::Value* _arg2, llvm::Valu llvm::Value* Arith256::mulmod(llvm::Value* _arg1, llvm::Value* _arg2, llvm::Value* _arg3) { - return ternaryOp(m_mulmod, _arg1, _arg2, _arg3); + return createCall(getMulModFunc(), {_arg1, _arg2, _arg3}); } namespace { using uint128 = __uint128_t; -// uint128 add(uint128 a, uint128 b) { return a + b; } -// uint128 mul(uint128 a, uint128 b) { return a * b; } -// -// uint128 mulq(uint64_t x, uint64_t y) -// { -// return (uint128)x * (uint128)y; -// } -// -// uint128 addc(uint64_t x, uint64_t y) -// { -// return (uint128)x * (uint128)y; -// } - struct uint256 { - uint64_t lo; - uint64_t mid; - uint128 hi; + uint64_t lo = 0; + uint64_t mid = 0; + uint128 hi = 0; + + uint256(uint64_t lo, uint64_t mid, uint128 hi): lo(lo), mid(mid), hi(hi) {} + uint256(uint128 n) + { + *((uint128*)&lo) = n; + } + + explicit operator uint128() + { + return *((uint128*)&lo); + } + + uint256 operator|(uint256 a) + { + return {lo | a.lo, mid | a.mid, hi | a.hi}; + } + + uint256 operator+(uint256 a) + { + auto _lo = (uint128) lo + a.lo; + auto _mid = (uint128) mid + a.mid + (_lo >> 64); + auto _hi = hi + a.hi + (_mid >> 64); + return {(uint64_t)_lo, (uint64_t)_mid, _hi}; + } + + uint256 lo2hi() + { + hi = (uint128)*this; + lo = 0; + mid = 0; + return *this; + } }; -// uint256 add(uint256 x, uint256 y) -// { -// auto lo = (uint128) x.lo + y.lo; -// auto mid = (uint128) x.mid + y.mid + (lo >> 64); -// return {lo, mid, x.hi + y.hi + (mid >> 64)}; -// } + struct uint512 + { + uint128 lo; + uint128 mid; + uint256 hi; + }; uint256 mul(uint256 x, uint256 y) { @@ -325,6 +386,23 @@ namespace return {lo, (uint64_t)mid, hi}; } + uint512 mul512(uint256 x, uint256 y) + { + auto x_lo = (uint128) x; + auto y_lo = (uint128) y; + + auto t1 = mul(x_lo, y_lo); + auto t2 = mul(x_lo, y.hi); + auto t3 = mul(x.hi, y_lo); + auto t4 = mul(x.hi, y.hi); + + auto lo = (uint128) t1; + auto mid = (uint256) t1.hi + (uint128) t2 + (uint128) t3; + auto hi = (uint256)t2.hi + t3.hi + t4 + mid.hi; + + return {lo, (uint128)mid, hi}; + } + inline void mul(i256* x, i256* y) { auto a = (uint256*) x; @@ -376,21 +454,9 @@ extern "C" *o_result = mul(*_arg1, *_arg2); } - EXPORT void arith_mulmod(i256* _arg1, i256* _arg2, i256* _arg3, i256* o_result) + EXPORT void arith_mul512(uint256* _arg1, uint256* _arg2, uint512* o_result) { - *o_result = {}; - if (isZero(_arg3)) - return; - - mpz_t x{nLimbs, countLimbs(_arg1), reinterpret_cast(_arg1)}; - mpz_t y{nLimbs, countLimbs(_arg2), reinterpret_cast(_arg2)}; - mpz_t m{nLimbs, countLimbs(_arg3), reinterpret_cast(_arg3)}; - mpz_t z{nLimbs, 0, reinterpret_cast(o_result)}; - static mp_limb_t p_limbs[nLimbs * 2] = {}; - static mpz_t p{nLimbs * 2, 0, &p_limbs[0]}; - - mpz_mul(p, x, y); - mpz_tdiv_r(z, p, m); + *o_result = mul512(*_arg1, *_arg2); } EXPORT void arith_addmod(i256* _arg1, i256* _arg2, i256* _arg3, i256* o_result) diff --git a/libevmjit/Arith256.h b/libevmjit/Arith256.h index 74ed63043..ae87d11f2 100644 --- a/libevmjit/Arith256.h +++ b/libevmjit/Arith256.h @@ -24,18 +24,20 @@ public: void debug(llvm::Value* _value, char _c); private: - llvm::Function* getDivFunc(); + llvm::Function* getDivFunc(llvm::Type* _type); llvm::Function* getExpFunc(); + llvm::Function* getMulModFunc(); llvm::Value* binaryOp(llvm::Function* _op, llvm::Value* _arg1, llvm::Value* _arg2); llvm::Value* ternaryOp(llvm::Function* _op, llvm::Value* _arg1, llvm::Value* _arg2, llvm::Value* _arg3); llvm::Function* m_mul; - llvm::Function* m_mulmod; llvm::Function* m_addmod; llvm::Function* m_div = nullptr; + llvm::Function* m_div512 = nullptr; llvm::Function* m_exp = nullptr; + llvm::Function* m_mulmod = nullptr; llvm::Function* m_debug = nullptr; llvm::Value* m_arg1;