Browse Source

New addmod algorithm

cl-refactor
Paweł Bylica 10 years ago
parent
commit
259a06e72b
  1. 96
      evmjit/libevmjit/Arith256.cpp
  2. 4
      evmjit/libevmjit/Arith256.h

96
evmjit/libevmjit/Arith256.cpp

@ -29,10 +29,8 @@ Arith256::Arith256(llvm::IRBuilder<>& _builder) :
using Linkage = GlobalValue::LinkageTypes; using Linkage = GlobalValue::LinkageTypes;
llvm::Type* arg2Types[] = {Type::WordPtr, Type::WordPtr, Type::WordPtr}; llvm::Type* arg2Types[] = {Type::WordPtr, Type::WordPtr, Type::WordPtr};
llvm::Type* arg3Types[] = {Type::WordPtr, Type::WordPtr, Type::WordPtr, Type::WordPtr};
m_mul = Function::Create(FunctionType::get(Type::Void, arg2Types, false), Linkage::ExternalLinkage, "arith_mul", getModule()); m_mul = Function::Create(FunctionType::get(Type::Void, arg2Types, false), Linkage::ExternalLinkage, "arith_mul", getModule());
m_addmod = Function::Create(FunctionType::get(Type::Void, arg3Types, false), Linkage::ExternalLinkage, "arith_addmod", getModule());
} }
void Arith256::debug(llvm::Value* _value, char _c) void Arith256::debug(llvm::Value* _value, char _c)
@ -210,6 +208,36 @@ llvm::Function* Arith256::getExpFunc()
return m_exp; return m_exp;
} }
/// Returns the IR function `addmod(x, y, m)`, emitting it on first use.
///
/// The function is created with private linkage in the current module and
/// cached in m_addmod; later calls return the cached pointer. Its body
/// zero-extends the three Word operands to 512 bits, adds x and y at that
/// width, passes the sum and the modulus to the 512-bit division helper and
/// returns element 1 of the helper's aggregate result, narrowed back to Word.
llvm::Function* Arith256::getAddModFunc()
{
	if (!m_addmod)
	{
		auto i512Ty = m_builder.getIntNTy(512);
		llvm::Type* argTypes[] = {Type::Word, Type::Word, Type::Word};
		m_addmod = llvm::Function::Create(llvm::FunctionType::get(Type::Word, argTypes, false), llvm::Function::PrivateLinkage, "addmod", getModule());

		auto x = &m_addmod->getArgumentList().front();
		x->setName("x");
		auto y = x->getNextNode();
		y->setName("y");
		auto mod = y->getNextNode();
		mod->setName("m");

		// NOTE(review): the guard presumably restores the builder's previous
		// insert point when it goes out of scope -- confirm against InsertPointGuard.
		InsertPointGuard guard{m_builder};

		auto entryBB = llvm::BasicBlock::Create(m_builder.getContext(), {}, m_addmod);
		m_builder.SetInsertPoint(entryBB);

		// Widen before adding so the carry out of x + y is not dropped
		// (assumes Type::Word is narrower than 512 bits -- TODO confirm).
		auto x512 = m_builder.CreateZExt(x, i512Ty, "x512");
		auto y512 = m_builder.CreateZExt(y, i512Ty, "y512");
		auto m512 = m_builder.CreateZExt(mod, i512Ty, "m512");
		auto s = m_builder.CreateAdd(x512, y512, "s");

		auto d = createCall(getDivFunc(i512Ty), {s, m512});
		auto r = m_builder.CreateExtractValue(d, 1, "r");

		// The helper was requested for i512 operands while this function is
		// declared to return Type::Word, so narrow the value before returning
		// it; otherwise the ret operand would not match the declared return
		// type. IRBuilder hands back the value unchanged when the types
		// already agree.
		m_builder.CreateRet(m_builder.CreateTrunc(r, Type::Word));
	}
	return m_addmod;
}
llvm::Function* Arith256::getMulModFunc() llvm::Function* Arith256::getMulModFunc()
{ {
if (!m_mulmod) if (!m_mulmod)
@ -247,7 +275,6 @@ llvm::Function* Arith256::getMulModFunc()
return m_mulmod; return m_mulmod;
} }
llvm::Value* Arith256::binaryOp(llvm::Function* _op, llvm::Value* _arg1, llvm::Value* _arg2) llvm::Value* Arith256::binaryOp(llvm::Function* _op, llvm::Value* _arg1, llvm::Value* _arg2)
{ {
m_builder.CreateStore(_arg1, m_arg1); m_builder.CreateStore(_arg1, m_arg1);
@ -256,15 +283,6 @@ llvm::Value* Arith256::binaryOp(llvm::Function* _op, llvm::Value* _arg1, llvm::V
return m_builder.CreateLoad(m_result); return m_builder.CreateLoad(m_result);
} }
/// Calls a three-operand helper through the shared argument slots.
///
/// Each operand is stored into its slot (m_arg1..m_arg3), _op is called with
/// the three slot pointers followed by m_result, and the value loaded from
/// m_result afterwards is returned.
llvm::Value* Arith256::ternaryOp(llvm::Function* _op, llvm::Value* _arg1, llvm::Value* _arg2, llvm::Value* _arg3)
{
	llvm::Value* const operands[] = {_arg1, _arg2, _arg3};
	llvm::Value* const slots[] = {m_arg1, m_arg2, m_arg3};

	// Stores are emitted in operand order: arg1, arg2, arg3.
	for (auto i = 0u; i < 3; ++i)
		m_builder.CreateStore(operands[i], slots[i]);

	m_builder.CreateCall4(_op, slots[0], slots[1], slots[2], m_result);
	return m_builder.CreateLoad(m_result);
}
llvm::Value* Arith256::mul(llvm::Value* _arg1, llvm::Value* _arg2) llvm::Value* Arith256::mul(llvm::Value* _arg1, llvm::Value* _arg2)
{ {
return binaryOp(m_mul, _arg1, _arg2); return binaryOp(m_mul, _arg1, _arg2);
@ -307,7 +325,7 @@ llvm::Value* Arith256::exp(llvm::Value* _arg1, llvm::Value* _arg2)
llvm::Value* Arith256::addmod(llvm::Value* _arg1, llvm::Value* _arg2, llvm::Value* _arg3) llvm::Value* Arith256::addmod(llvm::Value* _arg1, llvm::Value* _arg2, llvm::Value* _arg3)
{ {
return ternaryOp(m_addmod, _arg1, _arg2, _arg3); return createCall(getAddModFunc(), {_arg1, _arg2, _arg3});
} }
llvm::Value* Arith256::mulmod(llvm::Value* _arg1, llvm::Value* _arg2, llvm::Value* _arg3) llvm::Value* Arith256::mulmod(llvm::Value* _arg1, llvm::Value* _arg2, llvm::Value* _arg3)
@ -402,46 +420,14 @@ namespace
return {lo, (uint128)mid, hi}; return {lo, (uint128)mid, hi};
} }
// In-place multiply: reinterprets both i256 operands as uint256, calls
// mul(*a, *b) on the dereferenced values and stores the result back through x.
// y is only read.
inline void mul(i256* x, i256* y)
{
auto a = (uint256*) x;
auto b = (uint256*) y;
*a = mul(*a, *b);
} }
// Returns true only when all four fields (a, b, c, d) of *_n are zero.
bool isZero(i256 const* _n)
{
return _n->a == 0 && _n->b == 0 && _n->c == 0 && _n->d == 0;
} }
// Number of mp_limb_t-sized limbs that make up one i256.
const auto nLimbs = sizeof(i256) / sizeof(mp_limb_t);
// Returns how many limbs of *_n are in use: the fields are tested in the order
// d, c, b, a and the count is nLimbs for a non-zero d, one less for each field
// skipped after that, and 0 when every field is zero.
// NOTE(review): this presumably treats d as the most significant word and a as
// the least significant -- confirm against the i256 definition.
int countLimbs(i256 const* _n)
{
static const auto limbsInWord = sizeof(_n->a) / sizeof(mp_limb_t);
// The step-down below assumes every field is exactly one limb wide.
static_assert(limbsInWord == 1, "E?");
int l = nLimbs;
if (_n->d != 0) return l;
l -= limbsInWord;
if (_n->c != 0) return l;
l -= limbsInWord;
if (_n->b != 0) return l;
l -= limbsInWord;
if (_n->a != 0) return l;
return 0;
} }
} }
}
}
}
extern "C" extern "C"
{ {
using namespace dev::eth::jit; using namespace dev::eth::jit;
EXPORT void debug(uint64_t a, uint64_t b, uint64_t c, uint64_t d, char z) EXPORT void debug(uint64_t a, uint64_t b, uint64_t c, uint64_t d, char z)
@ -458,24 +444,4 @@ extern "C"
{ {
*o_result = mul512(*_arg1, *_arg2); *o_result = mul512(*_arg1, *_arg2);
} }
// Exported helper: writes (_arg1 + _arg2) mod _arg3 into *o_result using GMP.
// The result is zeroed first and is left at zero when the modulus is zero.
// NOTE(review): *o_result is cleared before the inputs are read, so this looks
// unsafe if o_result aliases one of the arguments -- confirm callers never do that.
EXPORT void arith_addmod(i256* _arg1, i256* _arg2, i256* _arg3, i256* o_result)
{
*o_result = {};
if (isZero(_arg3))
return;
// Build mpz_t values directly over the callers' i256 storage (no copy).
// NOTE(review): the brace initializers presumably map to GMP's
// {allocated limbs, used limbs, limb pointer} layout -- confirm against gmp.h.
mpz_t x{nLimbs, countLimbs(_arg1), reinterpret_cast<mp_limb_t*>(_arg1)};
mpz_t y{nLimbs, countLimbs(_arg2), reinterpret_cast<mp_limb_t*>(_arg2)};
mpz_t m{nLimbs, countLimbs(_arg3), reinterpret_cast<mp_limb_t*>(_arg3)};
mpz_t z{nLimbs, 0, reinterpret_cast<mp_limb_t*>(o_result)};
// Scratch for the sum, one limb larger than an i256 to hold the carry.
// NOTE(review): function-local static storage is shared by every call, so
// this looks non-reentrant -- confirm it is never called concurrently.
static mp_limb_t s_limbs[nLimbs + 1] = {};
static mpz_t s{nLimbs + 1, 0, &s_limbs[0]};
mpz_add(s, x, y);
// Remainder of s / m goes into z, i.e. into *o_result.
mpz_tdiv_r(z, s, m);
}
} }

4
evmjit/libevmjit/Arith256.h

@ -26,17 +26,17 @@ public:
private: private:
llvm::Function* getDivFunc(llvm::Type* _type); llvm::Function* getDivFunc(llvm::Type* _type);
llvm::Function* getExpFunc(); llvm::Function* getExpFunc();
llvm::Function* getAddModFunc();
llvm::Function* getMulModFunc(); llvm::Function* getMulModFunc();
llvm::Value* binaryOp(llvm::Function* _op, llvm::Value* _arg1, llvm::Value* _arg2); llvm::Value* binaryOp(llvm::Function* _op, llvm::Value* _arg1, llvm::Value* _arg2);
llvm::Value* ternaryOp(llvm::Function* _op, llvm::Value* _arg1, llvm::Value* _arg2, llvm::Value* _arg3);
llvm::Function* m_mul; llvm::Function* m_mul;
llvm::Function* m_addmod;
llvm::Function* m_div = nullptr; llvm::Function* m_div = nullptr;
llvm::Function* m_div512 = nullptr; llvm::Function* m_div512 = nullptr;
llvm::Function* m_exp = nullptr; llvm::Function* m_exp = nullptr;
llvm::Function* m_addmod = nullptr;
llvm::Function* m_mulmod = nullptr; llvm::Function* m_mulmod = nullptr;
llvm::Function* m_debug = nullptr; llvm::Function* m_debug = nullptr;

Loading…
Cancel
Save