diff --git a/libevmjit/Arith256.cpp b/libevmjit/Arith256.cpp index 2ec227dcc..0345a4854 100644 --- a/libevmjit/Arith256.cpp +++ b/libevmjit/Arith256.cpp @@ -5,6 +5,7 @@ #include #include + #include namespace dev @@ -16,20 +17,7 @@ namespace jit Arith256::Arith256(llvm::IRBuilder<>& _builder) : CompilerHelper(_builder) -{ - using namespace llvm; - - m_result = m_builder.CreateAlloca(Type::Word, nullptr, "arith.result"); - m_arg1 = m_builder.CreateAlloca(Type::Word, nullptr, "arith.arg1"); - m_arg2 = m_builder.CreateAlloca(Type::Word, nullptr, "arith.arg2"); - m_arg3 = m_builder.CreateAlloca(Type::Word, nullptr, "arith.arg3"); - - using Linkage = GlobalValue::LinkageTypes; - - llvm::Type* arg2Types[] = {Type::WordPtr, Type::WordPtr, Type::WordPtr}; - - m_mul = Function::Create(FunctionType::get(Type::Void, arg2Types, false), Linkage::ExternalLinkage, "arith_mul", getModule()); -} +{} void Arith256::debug(llvm::Value* _value, char _c) { @@ -41,6 +29,54 @@ void Arith256::debug(llvm::Value* _value, char _c) createCall(m_debug, {_value, m_builder.getInt8(_c)}); } +llvm::Function* Arith256::getMulFunc() +{ + auto& func = m_mul; + if (!func) + { + llvm::Type* argTypes[] = {Type::Word, Type::Word}; + func = llvm::Function::Create(llvm::FunctionType::get(Type::Word, argTypes, false), llvm::Function::PrivateLinkage, "mul", getModule()); + + auto x = &func->getArgumentList().front(); + x->setName("x"); + auto y = x->getNextNode(); + y->setName("y"); + + InsertPointGuard guard{m_builder}; + auto bb = llvm::BasicBlock::Create(m_builder.getContext(), {}, func); + m_builder.SetInsertPoint(bb); + auto i64 = Type::Size; + auto i128 = m_builder.getIntNTy(128); + auto i256 = Type::Word; + auto x_lo = m_builder.CreateTrunc(x, i64, "x.lo"); + auto y_lo = m_builder.CreateTrunc(y, i64, "y.lo"); + auto x_mi = m_builder.CreateTrunc(m_builder.CreateLShr(x, Constant::get(64)), i64); + auto y_mi = m_builder.CreateTrunc(m_builder.CreateLShr(y, Constant::get(64)), i64); + auto x_hi = m_builder.CreateTrunc(m_builder.CreateLShr(x, Constant::get(128)), i128); + auto y_hi = m_builder.CreateTrunc(m_builder.CreateLShr(y, Constant::get(128)), i128); + + auto t1 = m_builder.CreateMul(m_builder.CreateZExt(x_lo, i128), m_builder.CreateZExt(y_lo, i128)); + auto t2 = m_builder.CreateMul(m_builder.CreateZExt(x_lo, i128), m_builder.CreateZExt(y_mi, i128)); + auto t3 = m_builder.CreateMul(m_builder.CreateZExt(x_lo, i128), y_hi); + auto t4 = m_builder.CreateMul(m_builder.CreateZExt(x_mi, i128), m_builder.CreateZExt(y_lo, i128)); + auto t5 = m_builder.CreateMul(m_builder.CreateZExt(x_mi, i128), m_builder.CreateZExt(y_mi, i128)); + auto t6 = m_builder.CreateMul(m_builder.CreateZExt(x_mi, i128), y_hi); + auto t7 = m_builder.CreateMul(x_hi, m_builder.CreateZExt(y_lo, i128)); + auto t8 = m_builder.CreateMul(x_hi, m_builder.CreateZExt(y_mi, i128)); + + auto p = m_builder.CreateZExt(t1, i256); + p = m_builder.CreateAdd(p, m_builder.CreateShl(m_builder.CreateZExt(t2, i256), Constant::get(64))); + p = m_builder.CreateAdd(p, m_builder.CreateShl(m_builder.CreateZExt(t3, i256), Constant::get(128))); + p = m_builder.CreateAdd(p, m_builder.CreateShl(m_builder.CreateZExt(t4, i256), Constant::get(64))); + p = m_builder.CreateAdd(p, m_builder.CreateShl(m_builder.CreateZExt(t5, i256), Constant::get(128))); + p = m_builder.CreateAdd(p, m_builder.CreateShl(m_builder.CreateZExt(t6, i256), Constant::get(192))); + p = m_builder.CreateAdd(p, m_builder.CreateShl(m_builder.CreateZExt(t7, i256), Constant::get(128))); + p = m_builder.CreateAdd(p, m_builder.CreateShl(m_builder.CreateZExt(t8, i256), Constant::get(192))); + m_builder.CreateRet(p); + } + return m_mul; +} + llvm::Function* Arith256::getDivFunc(llvm::Type* _type) { auto& func = _type == Type::Word ? m_div : m_div512; @@ -135,7 +171,7 @@ llvm::Function* Arith256::getExpFunc() if (!m_exp) { llvm::Type* argTypes[] = {Type::Word, Type::Word}; - m_exp = llvm::Function::Create(llvm::FunctionType::get(Type::Word, argTypes, false), llvm::Function::PrivateLinkage, "arith.exp", getModule()); + m_exp = llvm::Function::Create(llvm::FunctionType::get(Type::Word, argTypes, false), llvm::Function::PrivateLinkage, "exp", getModule()); auto base = &m_exp->getArgumentList().front(); base->setName("base"); @@ -159,9 +195,6 @@ llvm::Function* Arith256::getExpFunc() auto returnBB = llvm::BasicBlock::Create(m_builder.getContext(), "Return", m_exp); m_builder.SetInsertPoint(entryBB); - auto a1 = m_builder.CreateAlloca(Type::Word, nullptr, "a1"); - auto a2 = m_builder.CreateAlloca(Type::Word, nullptr, "a2"); - auto a3 = m_builder.CreateAlloca(Type::Word, nullptr, "a3"); m_builder.CreateBr(headerBB); m_builder.SetInsertPoint(headerBB); @@ -176,20 +209,14 @@ llvm::Function* Arith256::getExpFunc() m_builder.CreateCondBr(eOdd, updateBB, continueBB); m_builder.SetInsertPoint(updateBB); - m_builder.CreateStore(r, a1); - m_builder.CreateStore(b, a2); - createCall(m_mul, {a1, a2, a3}); - auto r0 = m_builder.CreateLoad(a3, "r0"); + auto r0 = createCall(getMulFunc(), {r, b}); m_builder.CreateBr(continueBB); m_builder.SetInsertPoint(continueBB); auto r1 = m_builder.CreatePHI(Type::Word, 2, "r1"); r1->addIncoming(r, bodyBB); r1->addIncoming(r0, updateBB); - m_builder.CreateStore(b, a1); - m_builder.CreateStore(b, a2); - createCall(m_mul, {a1, a2, a3}); - auto b1 = m_builder.CreateLoad(a3, "b1"); + auto b1 = createCall(getMulFunc(), {b, b}); auto e1 = m_builder.CreateLShr(e, Constant::get(1), "e1"); m_builder.CreateBr(headerBB); @@ -273,17 +300,9 @@ llvm::Function* Arith256::getMulModFunc() return m_mulmod; } -llvm::Value* Arith256::binaryOp(llvm::Function* _op, llvm::Value* _arg1, llvm::Value* _arg2) -{ - m_builder.CreateStore(_arg1, m_arg1); - m_builder.CreateStore(_arg2, m_arg2); - m_builder.CreateCall3(_op, m_arg1, m_arg2, m_result); - return m_builder.CreateLoad(m_result); -} - llvm::Value* Arith256::mul(llvm::Value* _arg1, llvm::Value* _arg2) { - return binaryOp(m_mul, _arg1, _arg2); + return createCall(getMulFunc(), {_arg1, _arg2}); } std::pair Arith256::div(llvm::Value* _arg1, llvm::Value* _arg2) @@ -496,11 +515,6 @@ extern "C" std::cerr << "DEBUG " << z << ": " << d << c << b << a << std::endl; } - EXPORT void arith_mul(uint256* _arg1, uint256* _arg2, uint256* o_result) - { - *o_result = mul(*_arg1, *_arg2); - } - EXPORT void arith_mul512(uint256* _arg1, uint256* _arg2, uint512* o_result) { *o_result = mul512(*_arg1, *_arg2); diff --git a/libevmjit/Arith256.h b/libevmjit/Arith256.h index 5852137f8..f8f1c9eb2 100644 --- a/libevmjit/Arith256.h +++ b/libevmjit/Arith256.h @@ -24,26 +24,19 @@ public: void debug(llvm::Value* _value, char _c); private: + llvm::Function* getMulFunc(); llvm::Function* getDivFunc(llvm::Type* _type); llvm::Function* getExpFunc(); llvm::Function* getAddModFunc(); llvm::Function* getMulModFunc(); - llvm::Value* binaryOp(llvm::Function* _op, llvm::Value* _arg1, llvm::Value* _arg2); - - llvm::Function* m_mul; - + llvm::Function* m_mul = nullptr; llvm::Function* m_div = nullptr; llvm::Function* m_div512 = nullptr; llvm::Function* m_exp = nullptr; llvm::Function* m_addmod = nullptr; llvm::Function* m_mulmod = nullptr; llvm::Function* m_debug = nullptr; - - llvm::Value* m_arg1; - llvm::Value* m_arg2; - llvm::Value* m_arg3; - llvm::Value* m_result; };