Browse Source

New mulmod algorithm

cl-refactor
Paweł Bylica 10 years ago
parent
commit
6c2aa13e11
  1. 202
      libevmjit/Arith256.cpp
  2. 6
      libevmjit/Arith256.h

202
libevmjit/Arith256.cpp

@ -7,6 +7,7 @@
#include <llvm/IR/IntrinsicInst.h> #include <llvm/IR/IntrinsicInst.h>
#include <gmp.h> #include <gmp.h>
#include <iostream> #include <iostream>
#include <iomanip>
namespace dev namespace dev
{ {
@ -32,7 +33,6 @@ Arith256::Arith256(llvm::IRBuilder<>& _builder) :
m_mul = Function::Create(FunctionType::get(Type::Void, arg2Types, false), Linkage::ExternalLinkage, "arith_mul", getModule()); m_mul = Function::Create(FunctionType::get(Type::Void, arg2Types, false), Linkage::ExternalLinkage, "arith_mul", getModule());
m_addmod = Function::Create(FunctionType::get(Type::Void, arg3Types, false), Linkage::ExternalLinkage, "arith_addmod", getModule()); m_addmod = Function::Create(FunctionType::get(Type::Void, arg3Types, false), Linkage::ExternalLinkage, "arith_addmod", getModule());
m_mulmod = Function::Create(FunctionType::get(Type::Void, arg3Types, false), Linkage::ExternalLinkage, "arith_mulmod", getModule());
} }
void Arith256::debug(llvm::Value* _value, char _c) void Arith256::debug(llvm::Value* _value, char _c)
@ -45,64 +45,70 @@ void Arith256::debug(llvm::Value* _value, char _c)
createCall(m_debug, {_value, m_builder.getInt8(_c)}); createCall(m_debug, {_value, m_builder.getInt8(_c)});
} }
llvm::Function* Arith256::getDivFunc() llvm::Function* Arith256::getDivFunc(llvm::Type* _type)
{ {
if (!m_div) auto& func = _type == Type::Word ? m_div : m_div512;
if (!func)
{ {
// Based of "Improved shift divisor algorithm" from "Software Integer Division" by Microsoft Research // Based of "Improved shift divisor algorithm" from "Software Integer Division" by Microsoft Research
// The following algorithm also handles divisor of value 0 returning 0 for both quotient and reminder // The following algorithm also handles divisor of value 0 returning 0 for both quotient and reminder
llvm::Type* argTypes[] = {Type::Word, Type::Word}; llvm::Type* argTypes[] = {_type, _type};
auto retType = llvm::StructType::get(m_builder.getContext(), llvm::ArrayRef<llvm::Type*>{argTypes}); auto retType = llvm::StructType::get(m_builder.getContext(), llvm::ArrayRef<llvm::Type*>{argTypes});
m_div = llvm::Function::Create(llvm::FunctionType::get(retType, argTypes, false), llvm::Function::PrivateLinkage, "arith.div", getModule()); auto funcName = _type == Type::Word ? "div" : "div512";
func = llvm::Function::Create(llvm::FunctionType::get(retType, argTypes, false), llvm::Function::PrivateLinkage, funcName, getModule());
auto zero = llvm::ConstantInt::get(_type, 0);
auto one = llvm::ConstantInt::get(_type, 1);
auto x = &m_div->getArgumentList().front(); auto x = &func->getArgumentList().front();
x->setName("x"); x->setName("x");
auto yArg = x->getNextNode(); auto yArg = x->getNextNode();
yArg->setName("y"); yArg->setName("y");
InsertPointGuard guard{m_builder}; InsertPointGuard guard{m_builder};
auto entryBB = llvm::BasicBlock::Create(m_builder.getContext(), "Entry", m_div); auto entryBB = llvm::BasicBlock::Create(m_builder.getContext(), "Entry", func);
auto mainBB = llvm::BasicBlock::Create(m_builder.getContext(), "Main", m_div); auto mainBB = llvm::BasicBlock::Create(m_builder.getContext(), "Main", func);
auto loopBB = llvm::BasicBlock::Create(m_builder.getContext(), "Loop", m_div); auto loopBB = llvm::BasicBlock::Create(m_builder.getContext(), "Loop", func);
auto continueBB = llvm::BasicBlock::Create(m_builder.getContext(), "Continue", m_div); auto continueBB = llvm::BasicBlock::Create(m_builder.getContext(), "Continue", func);
auto returnBB = llvm::BasicBlock::Create(m_builder.getContext(), "Return", m_div); auto returnBB = llvm::BasicBlock::Create(m_builder.getContext(), "Return", func);
m_builder.SetInsertPoint(entryBB); m_builder.SetInsertPoint(entryBB);
auto yNonZero = m_builder.CreateICmpNE(yArg, Constant::get(0)); auto yNonZero = m_builder.CreateICmpNE(yArg, zero);
auto yLEx = m_builder.CreateICmpULE(yArg, x); auto yLEx = m_builder.CreateICmpULE(yArg, x);
auto r0 = m_builder.CreateSelect(yNonZero, x, Constant::get(0), "r0"); auto r0 = m_builder.CreateSelect(yNonZero, x, zero, "r0");
m_builder.CreateCondBr(m_builder.CreateAnd(yLEx, yNonZero), mainBB, returnBB); m_builder.CreateCondBr(m_builder.CreateAnd(yLEx, yNonZero), mainBB, returnBB);
m_builder.SetInsertPoint(mainBB); m_builder.SetInsertPoint(mainBB);
auto ctlzIntr = llvm::Intrinsic::getDeclaration(getModule(), llvm::Intrinsic::ctlz, Type::Word); auto ctlzIntr = llvm::Intrinsic::getDeclaration(getModule(), llvm::Intrinsic::ctlz, _type);
// both y and r are non-zero // both y and r are non-zero
auto yLz = m_builder.CreateCall2(ctlzIntr, yArg, m_builder.getInt1(true), "y.lz"); auto yLz = m_builder.CreateCall2(ctlzIntr, yArg, m_builder.getInt1(true), "y.lz");
auto rLz = m_builder.CreateCall2(ctlzIntr, r0, m_builder.getInt1(true), "r.lz"); auto rLz = m_builder.CreateCall2(ctlzIntr, r0, m_builder.getInt1(true), "r.lz");
auto i0 = m_builder.CreateNUWSub(yLz, rLz, "i0"); auto i0 = m_builder.CreateNUWSub(yLz, rLz, "i0");
auto shlBy0 = m_builder.CreateICmpEQ(i0, Constant::get(0)); auto shlBy0 = m_builder.CreateICmpEQ(i0, zero);
auto y0 = m_builder.CreateShl(yArg, i0); auto y0 = m_builder.CreateShl(yArg, i0);
y0 = m_builder.CreateSelect(shlBy0, yArg, y0, "y0"); // Workaround for LLVM bug: shl by 0 produces wrong result y0 = m_builder.CreateSelect(shlBy0, yArg, y0, "y0"); // Workaround for LLVM bug: shl by 0 produces wrong result
m_builder.CreateBr(loopBB); m_builder.CreateBr(loopBB);
m_builder.SetInsertPoint(loopBB); m_builder.SetInsertPoint(loopBB);
auto yPhi = m_builder.CreatePHI(Type::Word, 2, "y.phi"); auto yPhi = m_builder.CreatePHI(_type, 2, "y.phi");
auto rPhi = m_builder.CreatePHI(Type::Word, 2, "r.phi"); auto rPhi = m_builder.CreatePHI(_type, 2, "r.phi");
auto iPhi = m_builder.CreatePHI(Type::Word, 2, "i.phi"); auto iPhi = m_builder.CreatePHI(_type, 2, "i.phi");
auto qPhi = m_builder.CreatePHI(Type::Word, 2, "q.phi"); auto qPhi = m_builder.CreatePHI(_type, 2, "q.phi");
auto rUpdate = m_builder.CreateNUWSub(rPhi, yPhi); auto rUpdate = m_builder.CreateNUWSub(rPhi, yPhi);
auto qUpdate = m_builder.CreateOr(qPhi, Constant::get(1)); // q += 1, q lowest bit is 0 auto qUpdate = m_builder.CreateOr(qPhi, one); // q += 1, q lowest bit is 0
auto rGEy = m_builder.CreateICmpUGE(rPhi, yPhi); auto rGEy = m_builder.CreateICmpUGE(rPhi, yPhi);
auto r1 = m_builder.CreateSelect(rGEy, rUpdate, rPhi, "r1"); auto r1 = m_builder.CreateSelect(rGEy, rUpdate, rPhi, "r1");
auto q1 = m_builder.CreateSelect(rGEy, qUpdate, qPhi, "q"); auto q1 = m_builder.CreateSelect(rGEy, qUpdate, qPhi, "q");
auto iZero = m_builder.CreateICmpEQ(iPhi, Constant::get(0)); auto iZero = m_builder.CreateICmpEQ(iPhi, zero);
m_builder.CreateCondBr(iZero, returnBB, continueBB); m_builder.CreateCondBr(iZero, returnBB, continueBB);
m_builder.SetInsertPoint(continueBB); m_builder.SetInsertPoint(continueBB);
auto i2 = m_builder.CreateNUWSub(iPhi, Constant::get(1)); auto i2 = m_builder.CreateNUWSub(iPhi, one);
auto q2 = m_builder.CreateShl(q1, Constant::get(1)); auto q2 = m_builder.CreateShl(q1, one);
auto y2 = m_builder.CreateUDiv(yPhi, Constant::get(2)); auto y2 = m_builder.CreateLShr(yPhi, one);
m_builder.CreateBr(loopBB); m_builder.CreateBr(loopBB);
yPhi->addIncoming(y0, mainBB); yPhi->addIncoming(y0, mainBB);
@ -111,21 +117,21 @@ llvm::Function* Arith256::getDivFunc()
rPhi->addIncoming(r1, continueBB); rPhi->addIncoming(r1, continueBB);
iPhi->addIncoming(i0, mainBB); iPhi->addIncoming(i0, mainBB);
iPhi->addIncoming(i2, continueBB); iPhi->addIncoming(i2, continueBB);
qPhi->addIncoming(Constant::get(0), mainBB); qPhi->addIncoming(zero, mainBB);
qPhi->addIncoming(q2, continueBB); qPhi->addIncoming(q2, continueBB);
m_builder.SetInsertPoint(returnBB); m_builder.SetInsertPoint(returnBB);
auto qRet = m_builder.CreatePHI(Type::Word, 2, "q.ret"); auto qRet = m_builder.CreatePHI(_type, 2, "q.ret");
qRet->addIncoming(Constant::get(0), entryBB); qRet->addIncoming(zero, entryBB);
qRet->addIncoming(q1, loopBB); qRet->addIncoming(q1, loopBB);
auto rRet = m_builder.CreatePHI(Type::Word, 2, "r.ret"); auto rRet = m_builder.CreatePHI(_type, 2, "r.ret");
rRet->addIncoming(r0, entryBB); rRet->addIncoming(r0, entryBB);
rRet->addIncoming(r1, loopBB); rRet->addIncoming(r1, loopBB);
auto ret = m_builder.CreateInsertValue(llvm::UndefValue::get(retType), qRet, 0, "ret0"); auto ret = m_builder.CreateInsertValue(llvm::UndefValue::get(retType), qRet, 0, "ret0");
ret = m_builder.CreateInsertValue(ret, rRet, 1, "ret"); ret = m_builder.CreateInsertValue(ret, rRet, 1, "ret");
m_builder.CreateRet(ret); m_builder.CreateRet(ret);
} }
return m_div; return func;
} }
llvm::Function* Arith256::getExpFunc() llvm::Function* Arith256::getExpFunc()
@ -204,6 +210,42 @@ llvm::Function* Arith256::getExpFunc()
return m_exp; return m_exp;
} }
llvm::Function* Arith256::getMulModFunc()
{
if (!m_mulmod)
{
llvm::Type* argTypes[] = {Type::Word, Type::Word, Type::Word};
m_mulmod = llvm::Function::Create(llvm::FunctionType::get(Type::Word, argTypes, false), llvm::Function::PrivateLinkage, "mulmod", getModule());
auto i512Ty = m_builder.getIntNTy(512);
llvm::Type* mul512ArgTypes[] = {Type::WordPtr, Type::WordPtr, i512Ty->getPointerTo()};
auto mul512 = llvm::Function::Create(llvm::FunctionType::get(Type::Void, mul512ArgTypes, false), llvm::Function::ExternalLinkage, "arith_mul512", getModule());
auto x = &m_mulmod->getArgumentList().front();
x->setName("x");
auto y = x->getNextNode();
y->setName("y");
auto mod = y->getNextNode();
mod->setName("mod");
InsertPointGuard guard{m_builder};
auto entryBB = llvm::BasicBlock::Create(m_builder.getContext(), {}, m_mulmod);
m_builder.SetInsertPoint(entryBB);
auto a1 = m_builder.CreateAlloca(Type::Word);
auto a2 = m_builder.CreateAlloca(Type::Word);
auto a3 = m_builder.CreateAlloca(i512Ty);
m_builder.CreateStore(x, a1);
m_builder.CreateStore(y, a2);
createCall(mul512, {a1, a2, a3});
auto p = m_builder.CreateLoad(a3, "p");
auto m = m_builder.CreateZExt(mod, i512Ty, "m");
auto d = createCall(getDivFunc(i512Ty), {p, m});
auto r = m_builder.CreateExtractValue(d, 1, "r");
m_builder.CreateRet(r);
}
return m_mulmod;
}
llvm::Value* Arith256::binaryOp(llvm::Function* _op, llvm::Value* _arg1, llvm::Value* _arg2) llvm::Value* Arith256::binaryOp(llvm::Function* _op, llvm::Value* _arg1, llvm::Value* _arg2)
@ -230,8 +272,8 @@ llvm::Value* Arith256::mul(llvm::Value* _arg1, llvm::Value* _arg2)
std::pair<llvm::Value*, llvm::Value*> Arith256::div(llvm::Value* _arg1, llvm::Value* _arg2) std::pair<llvm::Value*, llvm::Value*> Arith256::div(llvm::Value* _arg1, llvm::Value* _arg2)
{ {
auto div = m_builder.CreateExtractValue(createCall(getDivFunc(), {_arg1, _arg2}), 0, "div"); auto div = m_builder.CreateExtractValue(createCall(getDivFunc(Type::Word), {_arg1, _arg2}), 0, "div");
auto mod = m_builder.CreateExtractValue(createCall(getDivFunc(), {_arg1, _arg2}), 1, "mod"); auto mod = m_builder.CreateExtractValue(createCall(getDivFunc(Type::Word), {_arg1, _arg2}), 1, "mod");
return std::make_pair(div, mod); return std::make_pair(div, mod);
} }
@ -270,39 +312,58 @@ llvm::Value* Arith256::addmod(llvm::Value* _arg1, llvm::Value* _arg2, llvm::Valu
llvm::Value* Arith256::mulmod(llvm::Value* _arg1, llvm::Value* _arg2, llvm::Value* _arg3) llvm::Value* Arith256::mulmod(llvm::Value* _arg1, llvm::Value* _arg2, llvm::Value* _arg3)
{ {
return ternaryOp(m_mulmod, _arg1, _arg2, _arg3); return createCall(getMulModFunc(), {_arg1, _arg2, _arg3});
} }
namespace namespace
{ {
using uint128 = __uint128_t; using uint128 = __uint128_t;
// uint128 add(uint128 a, uint128 b) { return a + b; }
// uint128 mul(uint128 a, uint128 b) { return a * b; }
//
// uint128 mulq(uint64_t x, uint64_t y)
// {
// return (uint128)x * (uint128)y;
// }
//
// uint128 addc(uint64_t x, uint64_t y)
// {
// return (uint128)x * (uint128)y;
// }
struct uint256 struct uint256
{ {
uint64_t lo; uint64_t lo = 0;
uint64_t mid; uint64_t mid = 0;
uint128 hi; uint128 hi = 0;
uint256(uint64_t lo, uint64_t mid, uint128 hi): lo(lo), mid(mid), hi(hi) {}
uint256(uint128 n)
{
*((uint128*)&lo) = n;
}
explicit operator uint128()
{
return *((uint128*)&lo);
}
uint256 operator|(uint256 a)
{
return {lo | a.lo, mid | a.mid, hi | a.hi};
}
uint256 operator+(uint256 a)
{
auto _lo = (uint128) lo + a.lo;
auto _mid = (uint128) mid + a.mid + (_lo >> 64);
auto _hi = hi + a.hi + (_mid >> 64);
return {(uint64_t)_lo, (uint64_t)_mid, _hi};
}
uint256 lo2hi()
{
hi = (uint128)*this;
lo = 0;
mid = 0;
return *this;
}
}; };
// uint256 add(uint256 x, uint256 y) struct uint512
// { {
// auto lo = (uint128) x.lo + y.lo; uint128 lo;
// auto mid = (uint128) x.mid + y.mid + (lo >> 64); uint128 mid;
// return {lo, mid, x.hi + y.hi + (mid >> 64)}; uint256 hi;
// } };
uint256 mul(uint256 x, uint256 y) uint256 mul(uint256 x, uint256 y)
{ {
@ -325,6 +386,23 @@ namespace
return {lo, (uint64_t)mid, hi}; return {lo, (uint64_t)mid, hi};
} }
uint512 mul512(uint256 x, uint256 y)
{
auto x_lo = (uint128) x;
auto y_lo = (uint128) y;
auto t1 = mul(x_lo, y_lo);
auto t2 = mul(x_lo, y.hi);
auto t3 = mul(x.hi, y_lo);
auto t4 = mul(x.hi, y.hi);
auto lo = (uint128) t1;
auto mid = (uint256) t1.hi + (uint128) t2 + (uint128) t3;
auto hi = (uint256)t2.hi + t3.hi + t4 + mid.hi;
return {lo, (uint128)mid, hi};
}
inline void mul(i256* x, i256* y) inline void mul(i256* x, i256* y)
{ {
auto a = (uint256*) x; auto a = (uint256*) x;
@ -376,21 +454,9 @@ extern "C"
*o_result = mul(*_arg1, *_arg2); *o_result = mul(*_arg1, *_arg2);
} }
EXPORT void arith_mulmod(i256* _arg1, i256* _arg2, i256* _arg3, i256* o_result) EXPORT void arith_mul512(uint256* _arg1, uint256* _arg2, uint512* o_result)
{ {
*o_result = {}; *o_result = mul512(*_arg1, *_arg2);
if (isZero(_arg3))
return;
mpz_t x{nLimbs, countLimbs(_arg1), reinterpret_cast<mp_limb_t*>(_arg1)};
mpz_t y{nLimbs, countLimbs(_arg2), reinterpret_cast<mp_limb_t*>(_arg2)};
mpz_t m{nLimbs, countLimbs(_arg3), reinterpret_cast<mp_limb_t*>(_arg3)};
mpz_t z{nLimbs, 0, reinterpret_cast<mp_limb_t*>(o_result)};
static mp_limb_t p_limbs[nLimbs * 2] = {};
static mpz_t p{nLimbs * 2, 0, &p_limbs[0]};
mpz_mul(p, x, y);
mpz_tdiv_r(z, p, m);
} }
EXPORT void arith_addmod(i256* _arg1, i256* _arg2, i256* _arg3, i256* o_result) EXPORT void arith_addmod(i256* _arg1, i256* _arg2, i256* _arg3, i256* o_result)

6
libevmjit/Arith256.h

@ -24,18 +24,20 @@ public:
void debug(llvm::Value* _value, char _c); void debug(llvm::Value* _value, char _c);
private: private:
llvm::Function* getDivFunc(); llvm::Function* getDivFunc(llvm::Type* _type);
llvm::Function* getExpFunc(); llvm::Function* getExpFunc();
llvm::Function* getMulModFunc();
llvm::Value* binaryOp(llvm::Function* _op, llvm::Value* _arg1, llvm::Value* _arg2); llvm::Value* binaryOp(llvm::Function* _op, llvm::Value* _arg1, llvm::Value* _arg2);
llvm::Value* ternaryOp(llvm::Function* _op, llvm::Value* _arg1, llvm::Value* _arg2, llvm::Value* _arg3); llvm::Value* ternaryOp(llvm::Function* _op, llvm::Value* _arg1, llvm::Value* _arg2, llvm::Value* _arg3);
llvm::Function* m_mul; llvm::Function* m_mul;
llvm::Function* m_mulmod;
llvm::Function* m_addmod; llvm::Function* m_addmod;
llvm::Function* m_div = nullptr; llvm::Function* m_div = nullptr;
llvm::Function* m_div512 = nullptr;
llvm::Function* m_exp = nullptr; llvm::Function* m_exp = nullptr;
llvm::Function* m_mulmod = nullptr;
llvm::Function* m_debug = nullptr; llvm::Function* m_debug = nullptr;
llvm::Value* m_arg1; llvm::Value* m_arg1;

Loading…
Cancel
Save