From d8da43e939fa613abfeaaae3cf1bc8963562318d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Bylica?= Date: Wed, 13 May 2015 18:55:14 +0200 Subject: [PATCH] Lower ADDMOD & MULMOD (limited) to a function call in the LLVM pass after optimization. --- libevmjit/Arith256.cpp | 183 +++++++++------------------------------- libevmjit/Arith256.h | 9 +- libevmjit/Compiler.cpp | 55 +++++++----- libevmjit/Optimizer.cpp | 10 +++ 4 files changed, 87 insertions(+), 170 deletions(-) diff --git a/libevmjit/Arith256.cpp b/libevmjit/Arith256.cpp index 64ed7e649..472c699a8 100644 --- a/libevmjit/Arith256.cpp +++ b/libevmjit/Arith256.cpp @@ -86,6 +86,47 @@ llvm::Function* Arith256::getMulFunc(llvm::Module& _module) return func; } +llvm::Function* Arith256::getMul512Func(llvm::Module& _module) +{ + static const auto funcName = "evm.mul.i512"; + if (auto func = _module.getFunction(funcName)) + return func; + + auto i512Ty = llvm::IntegerType::get(_module.getContext(), 512); + auto func = llvm::Function::Create(llvm::FunctionType::get(i512Ty, {Type::Word, Type::Word}, false), llvm::Function::PrivateLinkage, funcName, &_module); + func->setDoesNotThrow(); + func->setDoesNotAccessMemory(); + + auto x = &func->getArgumentList().front(); + x->setName("x"); + auto y = x->getNextNode(); + y->setName("y"); + + auto bb = llvm::BasicBlock::Create(_module.getContext(), {}, func); + auto builder = llvm::IRBuilder<>{bb}; + + auto i128 = builder.getIntNTy(128); + auto i256 = Type::Word; + auto x_lo = builder.CreateZExt(builder.CreateTrunc(x, i128, "x.lo"), i256); + auto y_lo = builder.CreateZExt(builder.CreateTrunc(y, i128, "y.lo"), i256); + auto x_hi = builder.CreateZExt(builder.CreateTrunc(builder.CreateLShr(x, Constant::get(128)), i128, "x.hi"), i256); + auto y_hi = builder.CreateZExt(builder.CreateTrunc(builder.CreateLShr(y, Constant::get(128)), i128, "y.hi"), i256); + + auto mul256Func = getMulFunc(_module); + auto t1 = builder.CreateCall(mul256Func, {x_lo, y_lo}); + auto t2 = 
builder.CreateCall(mul256Func, {x_lo, y_hi}); + auto t3 = builder.CreateCall(mul256Func, {x_hi, y_lo}); + auto t4 = builder.CreateCall(mul256Func, {x_hi, y_hi}); + + auto p = builder.CreateZExt(t1, i512Ty); + p = builder.CreateAdd(p, builder.CreateShl(builder.CreateZExt(t2, i512Ty), builder.getIntN(512, 128))); + p = builder.CreateAdd(p, builder.CreateShl(builder.CreateZExt(t3, i512Ty), builder.getIntN(512, 128))); + p = builder.CreateAdd(p, builder.CreateShl(builder.CreateZExt(t4, i512Ty), builder.getIntN(512, 256))); + builder.CreateRet(p); + + return func; +} + namespace { llvm::Function* createUDivRemFunc(llvm::Type* _type, llvm::Module& _module, char const* _funcName) @@ -354,47 +395,6 @@ llvm::Function* Arith256::getSRem256Func(llvm::Module& _module) return func; } -llvm::Function* Arith256::getMul512Func() -{ - auto& func = m_mul512; - if (!func) - { - auto i512 = m_builder.getIntNTy(512); - llvm::Type* argTypes[] = {Type::Word, Type::Word}; - func = llvm::Function::Create(llvm::FunctionType::get(i512, argTypes, false), llvm::Function::PrivateLinkage, "mul512", getModule()); - func->setDoesNotThrow(); - func->setDoesNotAccessMemory(); - - auto x = &func->getArgumentList().front(); - x->setName("x"); - auto y = x->getNextNode(); - y->setName("y"); - - InsertPointGuard guard{m_builder}; - auto bb = llvm::BasicBlock::Create(m_builder.getContext(), {}, func); - m_builder.SetInsertPoint(bb); - auto i128 = m_builder.getIntNTy(128); - auto i256 = Type::Word; - auto x_lo = m_builder.CreateZExt(m_builder.CreateTrunc(x, i128, "x.lo"), i256); - auto y_lo = m_builder.CreateZExt(m_builder.CreateTrunc(y, i128, "y.lo"), i256); - auto x_hi = m_builder.CreateZExt(m_builder.CreateTrunc(m_builder.CreateLShr(x, Constant::get(128)), i128, "x.hi"), i256); - auto y_hi = m_builder.CreateZExt(m_builder.CreateTrunc(m_builder.CreateLShr(y, Constant::get(128)), i128, "y.hi"), i256); - - auto mul256Func = getMulFunc(*getModule()); - auto t1 = createCall(mul256Func, {x_lo, y_lo}); - auto 
t2 = createCall(mul256Func, {x_lo, y_hi}); - auto t3 = createCall(mul256Func, {x_hi, y_lo}); - auto t4 = createCall(mul256Func, {x_hi, y_hi}); - - auto p = m_builder.CreateZExt(t1, i512); - p = m_builder.CreateAdd(p, m_builder.CreateShl(m_builder.CreateZExt(t2, i512), m_builder.getIntN(512, 128))); - p = m_builder.CreateAdd(p, m_builder.CreateShl(m_builder.CreateZExt(t3, i512), m_builder.getIntN(512, 128))); - p = m_builder.CreateAdd(p, m_builder.CreateShl(m_builder.CreateZExt(t4, i512), m_builder.getIntN(512, 256))); - m_builder.CreateRet(p); - } - return func; -} - llvm::Function* Arith256::getExpFunc() { if (!m_exp) @@ -465,66 +465,6 @@ llvm::Function* Arith256::getExpFunc() return m_exp; } -llvm::Function* Arith256::getAddModFunc() -{ - if (!m_addmod) - { - auto i512Ty = m_builder.getIntNTy(512); - llvm::Type* argTypes[] = {Type::Word, Type::Word, Type::Word}; - m_addmod = llvm::Function::Create(llvm::FunctionType::get(Type::Word, argTypes, false), llvm::Function::PrivateLinkage, "addmod", getModule()); - m_addmod->setDoesNotThrow(); - m_addmod->setDoesNotAccessMemory(); - - auto x = &m_addmod->getArgumentList().front(); - x->setName("x"); - auto y = x->getNextNode(); - y->setName("y"); - auto mod = y->getNextNode(); - mod->setName("m"); - - InsertPointGuard guard{m_builder}; - - auto entryBB = llvm::BasicBlock::Create(m_builder.getContext(), {}, m_addmod); - m_builder.SetInsertPoint(entryBB); - auto x512 = m_builder.CreateZExt(x, i512Ty, "x512"); - auto y512 = m_builder.CreateZExt(y, i512Ty, "y512"); - auto m512 = m_builder.CreateZExt(mod, i512Ty, "m512"); - auto s = m_builder.CreateAdd(x512, y512, "s"); - auto r = createCall(getURem512Func(*getModule()), {s, m512}); - m_builder.CreateRet(m_builder.CreateTrunc(r, Type::Word)); - } - return m_addmod; -} - -llvm::Function* Arith256::getMulModFunc() -{ - if (!m_mulmod) - { - llvm::Type* argTypes[] = {Type::Word, Type::Word, Type::Word}; - m_mulmod = llvm::Function::Create(llvm::FunctionType::get(Type::Word, 
argTypes, false), llvm::Function::PrivateLinkage, "mulmod", getModule()); - m_mulmod->setDoesNotThrow(); - m_mulmod->setDoesNotAccessMemory(); - - auto i512Ty = m_builder.getIntNTy(512); - auto x = &m_mulmod->getArgumentList().front(); - x->setName("x"); - auto y = x->getNextNode(); - y->setName("y"); - auto mod = y->getNextNode(); - mod->setName("mod"); - - InsertPointGuard guard{m_builder}; - - auto entryBB = llvm::BasicBlock::Create(m_builder.getContext(), {}, m_mulmod); - m_builder.SetInsertPoint(entryBB); - auto p = createCall(getMul512Func(), {x, y}); - auto m = m_builder.CreateZExt(mod, i512Ty, "m"); - auto r = createCall(getURem512Func(*getModule()), {p, m}); - m_builder.CreateRet(m_builder.CreateTrunc(r, Type::Word)); - } - return m_mulmod; -} - llvm::Value* Arith256::exp(llvm::Value* _arg1, llvm::Value* _arg2) { // while (e != 0) { @@ -555,47 +495,6 @@ llvm::Value* Arith256::exp(llvm::Value* _arg1, llvm::Value* _arg2) return createCall(getExpFunc(), {_arg1, _arg2}); } -llvm::Value* Arith256::addmod(llvm::Value* _arg1, llvm::Value* _arg2, llvm::Value* _arg3) -{ - if (auto c1 = llvm::dyn_cast<llvm::ConstantInt>(_arg1)) - { - if (auto c2 = llvm::dyn_cast<llvm::ConstantInt>(_arg2)) - { - if (auto c3 = llvm::dyn_cast<llvm::ConstantInt>(_arg3)) - { - if (!c3->getValue()) - return Constant::get(0); - auto s = c1->getValue().zext(256+64) + c2->getValue().zext(256+64); - auto r = s.urem(c3->getValue().zext(256+64)).trunc(256); - return Constant::get(r); - } - } - } - - return createCall(getAddModFunc(), {_arg1, _arg2, _arg3}); -} - -llvm::Value* Arith256::mulmod(llvm::Value* _arg1, llvm::Value* _arg2, llvm::Value* _arg3) -{ - if (auto c1 = llvm::dyn_cast<llvm::ConstantInt>(_arg1)) - { - if (auto c2 = llvm::dyn_cast<llvm::ConstantInt>(_arg2)) - { - if (auto c3 = llvm::dyn_cast<llvm::ConstantInt>(_arg3)) - { - if (!c3->getValue()) - return Constant::get(0); - auto p = c1->getValue().zext(512) * c2->getValue().zext(512); - auto r = p.urem(c3->getValue().zext(512)).trunc(256); - return Constant::get(r); - } - } - } - - return createCall(getMulModFunc(), {_arg1, _arg2, _arg3}); -} 
- - } } } diff --git a/libevmjit/Arith256.h b/libevmjit/Arith256.h index d6096a4c2..81535a792 100644 --- a/libevmjit/Arith256.h +++ b/libevmjit/Arith256.h @@ -15,12 +15,11 @@ public: Arith256(llvm::IRBuilder<>& _builder); llvm::Value* exp(llvm::Value* _arg1, llvm::Value* _arg2); - llvm::Value* mulmod(llvm::Value* _arg1, llvm::Value* _arg2, llvm::Value* _arg3); - llvm::Value* addmod(llvm::Value* _arg1, llvm::Value* _arg2, llvm::Value* _arg3); void debug(llvm::Value* _value, char _c); static llvm::Function* getMulFunc(llvm::Module& _module); + static llvm::Function* getMul512Func(llvm::Module& _module); static llvm::Function* getUDiv256Func(llvm::Module& _module); static llvm::Function* getURem256Func(llvm::Module& _module); static llvm::Function* getURem512Func(llvm::Module& _module); @@ -31,15 +30,9 @@ public: static llvm::Function* getUDivRem512Func(llvm::Module& _module); private: - llvm::Function* getMul512Func(); llvm::Function* getExpFunc(); - llvm::Function* getAddModFunc(); - llvm::Function* getMulModFunc(); - llvm::Function* m_mul512 = nullptr; llvm::Function* m_exp = nullptr; - llvm::Function* m_addmod = nullptr; - llvm::Function* m_mulmod = nullptr; llvm::Function* m_debug = nullptr; }; diff --git a/libevmjit/Compiler.cpp b/libevmjit/Compiler.cpp index 33d802f90..56f1904a1 100644 --- a/libevmjit/Compiler.cpp +++ b/libevmjit/Compiler.cpp @@ -322,6 +322,41 @@ void Compiler::compileBasicBlock(BasicBlock& _basicBlock, RuntimeManager& _runti break; } + case Instruction::ADDMOD: + { + auto i512Ty = m_builder.getIntNTy(512); + auto a = stack.pop(); + auto b = stack.pop(); + auto m = stack.pop(); + auto divByZero = m_builder.CreateICmpEQ(m, Constant::get(0)); + a = m_builder.CreateZExt(a, i512Ty); + b = m_builder.CreateZExt(b, i512Ty); + m = m_builder.CreateZExt(m, i512Ty); + auto s = m_builder.CreateNUWAdd(a, b); + s = m_builder.CreateURem(s, m); + s = m_builder.CreateTrunc(s, Type::Word); + s = m_builder.CreateSelect(divByZero, Constant::get(0), s); + 
stack.push(s); + break; + } + + case Instruction::MULMOD: + { + auto i512Ty = m_builder.getIntNTy(512); + auto a = stack.pop(); + auto b = stack.pop(); + auto m = stack.pop(); + auto divByZero = m_builder.CreateICmpEQ(m, Constant::get(0)); + m = m_builder.CreateZExt(m, i512Ty); + // TODO: Add support for i256 x i256 -> i512 in LowerEVM pass + llvm::Value* p = m_builder.CreateCall(Arith256::getMul512Func(*_basicBlock.llvm()->getParent()->getParent()), {a, b}); + p = m_builder.CreateURem(p, m); + p = m_builder.CreateTrunc(p, Type::Word); + p = m_builder.CreateSelect(divByZero, Constant::get(0), p); + stack.push(p); + break; + } + case Instruction::EXP: { auto base = stack.pop(); @@ -440,26 +475,6 @@ void Compiler::compileBasicBlock(BasicBlock& _basicBlock, RuntimeManager& _runti break; } - case Instruction::ADDMOD: - { - auto lhs = stack.pop(); - auto rhs = stack.pop(); - auto mod = stack.pop(); - auto res = _arith.addmod(lhs, rhs, mod); - stack.push(res); - break; - } - - case Instruction::MULMOD: - { - auto lhs = stack.pop(); - auto rhs = stack.pop(); - auto mod = stack.pop(); - auto res = _arith.mulmod(lhs, rhs, mod); - stack.push(res); - break; - } - case Instruction::SIGNEXTEND: { auto idx = stack.pop(); diff --git a/libevmjit/Optimizer.cpp b/libevmjit/Optimizer.cpp index 8eaab9a99..52bf14efa 100644 --- a/libevmjit/Optimizer.cpp +++ b/libevmjit/Optimizer.cpp @@ -54,6 +54,7 @@ bool LowerEVMPass::runOnBasicBlock(llvm::BasicBlock& _bb) { auto modified = false; auto module = _bb.getParent()->getParent(); + auto i512Ty = llvm::IntegerType::get(_bb.getContext(), 512); for (auto it = _bb.begin(); it != _bb.end(); ) { auto& inst = *it++; @@ -83,6 +84,15 @@ bool LowerEVMPass::runOnBasicBlock(llvm::BasicBlock& _bb) break; } } + else if (inst.getType() == i512Ty) + { + switch (inst.getOpcode()) + { + case llvm::Instruction::URem: + func = Arith256::getURem512Func(*module); + break; + } + } if (func) {