diff --git a/libevmjit/Arith256.cpp b/libevmjit/Arith256.cpp index 0345a4854..a77050adc 100644 --- a/libevmjit/Arith256.cpp +++ b/libevmjit/Arith256.cpp @@ -74,7 +74,45 @@ llvm::Function* Arith256::getMulFunc() p = m_builder.CreateAdd(p, m_builder.CreateShl(m_builder.CreateZExt(t8, i256), Constant::get(192))); m_builder.CreateRet(p); } - return m_mul; + return func; +} + +llvm::Function* Arith256::getMul512Func() +{ + auto& func = m_mul512; + if (!func) + { + auto i512 = m_builder.getIntNTy(512); + llvm::Type* argTypes[] = {Type::Word, Type::Word}; + func = llvm::Function::Create(llvm::FunctionType::get(i512, argTypes, false), llvm::Function::PrivateLinkage, "mul", getModule()); + + auto x = &func->getArgumentList().front(); + x->setName("x"); + auto y = x->getNextNode(); + y->setName("y"); + + InsertPointGuard guard{m_builder}; + auto bb = llvm::BasicBlock::Create(m_builder.getContext(), {}, func); + m_builder.SetInsertPoint(bb); + auto i128 = m_builder.getIntNTy(128); + auto i256 = Type::Word; + auto x_lo = m_builder.CreateTrunc(x, i128, "x.lo"); + auto y_lo = m_builder.CreateTrunc(y, i128, "y.lo"); + auto x_hi = m_builder.CreateTrunc(m_builder.CreateLShr(x, Constant::get(128)), i128, "x.hi"); + auto y_hi = m_builder.CreateTrunc(m_builder.CreateLShr(y, Constant::get(128)), i128, "y.hi"); + + auto t1 = createCall(getMulFunc(), {m_builder.CreateZExt(x_lo, i256), m_builder.CreateZExt(y_lo, i256)}); + auto t2 = createCall(getMulFunc(), {m_builder.CreateZExt(x_lo, i256), m_builder.CreateZExt(y_hi, i256)}); + auto t3 = createCall(getMulFunc(), {m_builder.CreateZExt(x_hi, i256), m_builder.CreateZExt(y_lo, i256)}); + auto t4 = createCall(getMulFunc(), {m_builder.CreateZExt(x_hi, i256), m_builder.CreateZExt(y_hi, i256)}); + + auto p = m_builder.CreateZExt(t1, i512); + p = m_builder.CreateAdd(p, m_builder.CreateShl(m_builder.CreateZExt(t2, i512), m_builder.getIntN(512, 128))); + p = m_builder.CreateAdd(p, m_builder.CreateShl(m_builder.CreateZExt(t3, i512), m_builder.getIntN(512, 128))); + p = m_builder.CreateAdd(p, m_builder.CreateShl(m_builder.CreateZExt(t4, i512), m_builder.getIntN(512, 256))); + m_builder.CreateRet(p); + } + return func; } llvm::Function* Arith256::getDivFunc(llvm::Type* _type) @@ -271,9 +309,6 @@ llvm::Function* Arith256::getMulModFunc() m_mulmod = llvm::Function::Create(llvm::FunctionType::get(Type::Word, argTypes, false), llvm::Function::PrivateLinkage, "mulmod", getModule()); auto i512Ty = m_builder.getIntNTy(512); - llvm::Type* mul512ArgTypes[] = {Type::WordPtr, Type::WordPtr, i512Ty->getPointerTo()}; - auto mul512 = llvm::Function::Create(llvm::FunctionType::get(Type::Void, mul512ArgTypes, false), llvm::Function::ExternalLinkage, "arith_mul512", getModule()); - auto x = &m_mulmod->getArgumentList().front(); x->setName("x"); auto y = x->getNextNode(); @@ -285,13 +320,7 @@ llvm::Function* Arith256::getMulModFunc() auto entryBB = llvm::BasicBlock::Create(m_builder.getContext(), {}, m_mulmod); m_builder.SetInsertPoint(entryBB); - auto a1 = m_builder.CreateAlloca(Type::Word); - auto a2 = m_builder.CreateAlloca(Type::Word); - auto a3 = m_builder.CreateAlloca(i512Ty); - m_builder.CreateStore(x, a1); - m_builder.CreateStore(y, a2); - createCall(mul512, {a1, a2, a3}); - auto p = m_builder.CreateLoad(a3, "p"); + auto p = createCall(getMul512Func(), {x, y}); auto m = m_builder.CreateZExt(mod, i512Ty, "m"); auto d = createCall(getDivFunc(i512Ty), {p, m}); auto r = m_builder.CreateExtractValue(d, 1, "r"); @@ -350,157 +379,6 @@ llvm::Value* Arith256::mulmod(llvm::Value* _arg1, llvm::Value* _arg2, llvm::Valu return createCall(getMulModFunc(), {_arg1, _arg2, _arg3}); } -namespace -{ -#ifdef __SIZEOF_INT128__ - using uint128 = __uint128_t; -#else - struct uint128 - { - uint64_t lo = 0; - uint64_t hi = 0; - - uint128(uint64_t lo) : lo(lo) {} - - uint128 operator+(uint128 a) - { - uint128 r = 0; - bool overflow = lo > std::numeric_limits::max() - a.lo; - r.lo = lo + a.lo; - r.hi = hi + a.hi + overflow; - return r; - } - - uint128 operator>>(int s) - { - assert(s == 64); - return hi; - } - - uint128 operator<<(int s) - { - assert(s == 64); - uint128 r = 0; - r.hi = lo; - return r; - } - - explicit operator uint64_t() { return lo; } - - static uint128 mul(uint64_t a, uint64_t b) - { - auto x_lo = 0xFFFFFFFF & a; - auto y_lo = 0xFFFFFFFF & b; - auto x_hi = a >> 32; - auto y_hi = b >> 32; - - auto t1 = x_lo * y_lo; - auto t2 = x_lo * y_hi; - auto t3 = x_hi * y_lo; - auto t4 = x_hi * y_hi; - - auto lo = (uint32_t)t1; - auto mid = (uint64_t)(t1 >> 32) + (uint32_t)t2 + (uint32_t)t3; - auto hi = (uint64_t)(t2 >> 32) + (t3 >> 32) + t4 + (mid >> 32); - - uint128 r = 0; - r.lo = (uint64_t)lo + (mid << 32); - r.hi = hi; - return r; - } - - uint128 operator*(uint128 a) - { - auto t1 = mul(lo, a.lo); - auto t2 = mul(lo, a.hi); - auto t3 = mul(hi, a.lo); - return t1 + (t2 << 64) + (t3 << 64); - } - }; -#endif - - struct uint256 - { - uint64_t lo = 0; - uint64_t mid = 0; - uint128 hi = 0; - - uint256(uint64_t lo, uint64_t mid, uint128 hi): lo(lo), mid(mid), hi(hi) {} - uint256(uint128 n) - { - lo = (uint64_t) n; - mid = (uint64_t) (n >> 64); - } - - explicit operator uint128() - { - uint128 r = lo; - r |= ((uint128) mid) << 64; - return r; - } - - uint256 operator+(uint256 a) - { - auto _lo = (uint128) lo + a.lo; - auto _mid = (uint128) mid + a.mid + (_lo >> 64); - auto _hi = hi + a.hi + (_mid >> 64); - return {(uint64_t)_lo, (uint64_t)_mid, _hi}; - } - - uint256 lo2hi() - { - hi = (uint128)*this; - lo = 0; - mid = 0; - return *this; - } - }; - - struct uint512 - { - uint128 lo; - uint128 mid; - uint256 hi; - }; - - uint256 mul(uint256 x, uint256 y) - { - auto t1 = (uint128) x.lo * y.lo; - auto t2 = (uint128) x.lo * y.mid; - auto t3 = (uint128) x.lo * y.hi; - auto t4 = (uint128) x.mid * y.lo; - auto t5 = (uint128) x.mid * y.mid; - auto t6 = (uint128) x.mid * y.hi; - auto t7 = x.hi * y.lo; - auto t8 = x.hi * y.mid; - - auto lo = (uint64_t) t1; - auto m1 = (t1 >> 64) + (uint64_t) t2; - auto m2 = (uint64_t) m1; - auto mid = (uint128) m2 + (uint64_t) t4; - auto hi = (t2 >> 64) + t3 + (t4 >> 64) + t5 + (t6 << 64) + t7 - + (t8 << 64) + (m1 >> 64) + (mid >> 64); - - return {lo, (uint64_t)mid, hi}; - } - - uint512 mul512(uint256 x, uint256 y) - { - auto x_lo = (uint128) x; - auto y_lo = (uint128) y; - - auto t1 = mul(x_lo, y_lo); - auto t2 = mul(x_lo, y.hi); - auto t3 = mul(x.hi, y_lo); - auto t4 = mul(x.hi, y.hi); - - auto lo = (uint128) t1; - auto mid = (uint256) t1.hi + (uint128) t2 + (uint128) t3; - auto hi = (uint256)t2.hi + t3.hi + t4 + mid.hi; - - return {lo, (uint128)mid, hi}; - } -} } } @@ -508,15 +386,8 @@ namespace extern "C" { - using namespace dev::eth::jit; - EXPORT void debug(uint64_t a, uint64_t b, uint64_t c, uint64_t d, char z) { std::cerr << "DEBUG " << z << ": " << d << c << b << a << std::endl; } - - EXPORT void arith_mul512(uint256* _arg1, uint256* _arg2, uint512* o_result) - { - *o_result = mul512(*_arg1, *_arg2); - } } diff --git a/libevmjit/Arith256.h b/libevmjit/Arith256.h index f8f1c9eb2..2513ca568 100644 --- a/libevmjit/Arith256.h +++ b/libevmjit/Arith256.h @@ -25,12 +25,14 @@ public: private: llvm::Function* getMulFunc(); + llvm::Function* getMul512Func(); llvm::Function* getDivFunc(llvm::Type* _type); llvm::Function* getExpFunc(); llvm::Function* getAddModFunc(); llvm::Function* getMulModFunc(); llvm::Function* m_mul = nullptr; + llvm::Function* m_mul512 = nullptr; llvm::Function* m_div = nullptr; llvm::Function* m_div512 = nullptr; llvm::Function* m_exp = nullptr;