Browse Source

Move mul function to LLVM

cl-refactor
Paweł Bylica 10 years ago
parent
commit
d58f35343b
  1. 94
      libevmjit/Arith256.cpp
  2. 11
      libevmjit/Arith256.h

94
libevmjit/Arith256.cpp

@ -5,6 +5,7 @@
#include <llvm/IR/Function.h> #include <llvm/IR/Function.h>
#include <llvm/IR/IntrinsicInst.h> #include <llvm/IR/IntrinsicInst.h>
#include <iostream> #include <iostream>
namespace dev namespace dev
@ -16,20 +17,7 @@ namespace jit
Arith256::Arith256(llvm::IRBuilder<>& _builder) : Arith256::Arith256(llvm::IRBuilder<>& _builder) :
CompilerHelper(_builder) CompilerHelper(_builder)
{ {}
using namespace llvm;
m_result = m_builder.CreateAlloca(Type::Word, nullptr, "arith.result");
m_arg1 = m_builder.CreateAlloca(Type::Word, nullptr, "arith.arg1");
m_arg2 = m_builder.CreateAlloca(Type::Word, nullptr, "arith.arg2");
m_arg3 = m_builder.CreateAlloca(Type::Word, nullptr, "arith.arg3");
using Linkage = GlobalValue::LinkageTypes;
llvm::Type* arg2Types[] = {Type::WordPtr, Type::WordPtr, Type::WordPtr};
m_mul = Function::Create(FunctionType::get(Type::Void, arg2Types, false), Linkage::ExternalLinkage, "arith_mul", getModule());
}
void Arith256::debug(llvm::Value* _value, char _c) void Arith256::debug(llvm::Value* _value, char _c)
{ {
@ -41,6 +29,54 @@ void Arith256::debug(llvm::Value* _value, char _c)
createCall(m_debug, {_value, m_builder.getInt8(_c)}); createCall(m_debug, {_value, m_builder.getInt8(_c)});
} }
llvm::Function* Arith256::getMulFunc()
{
auto& func = m_mul;
if (!func)
{
llvm::Type* argTypes[] = {Type::Word, Type::Word};
func = llvm::Function::Create(llvm::FunctionType::get(Type::Word, argTypes, false), llvm::Function::PrivateLinkage, "mul", getModule());
auto x = &func->getArgumentList().front();
x->setName("x");
auto y = x->getNextNode();
y->setName("y");
InsertPointGuard guard{m_builder};
auto bb = llvm::BasicBlock::Create(m_builder.getContext(), {}, func);
m_builder.SetInsertPoint(bb);
auto i64 = Type::Size;
auto i128 = m_builder.getIntNTy(128);
auto i256 = Type::Word;
auto x_lo = m_builder.CreateTrunc(x, i64, "x.lo");
auto y_lo = m_builder.CreateTrunc(y, i64, "y.lo");
auto x_mi = m_builder.CreateTrunc(m_builder.CreateLShr(x, Constant::get(64)), i64);
auto y_mi = m_builder.CreateTrunc(m_builder.CreateLShr(y, Constant::get(64)), i64);
auto x_hi = m_builder.CreateTrunc(m_builder.CreateLShr(x, Constant::get(128)), i128);
auto y_hi = m_builder.CreateTrunc(m_builder.CreateLShr(y, Constant::get(128)), i128);
auto t1 = m_builder.CreateMul(m_builder.CreateZExt(x_lo, i128), m_builder.CreateZExt(y_lo, i128));
auto t2 = m_builder.CreateMul(m_builder.CreateZExt(x_lo, i128), m_builder.CreateZExt(y_mi, i128));
auto t3 = m_builder.CreateMul(m_builder.CreateZExt(x_lo, i128), y_hi);
auto t4 = m_builder.CreateMul(m_builder.CreateZExt(x_mi, i128), m_builder.CreateZExt(y_lo, i128));
auto t5 = m_builder.CreateMul(m_builder.CreateZExt(x_mi, i128), m_builder.CreateZExt(y_mi, i128));
auto t6 = m_builder.CreateMul(m_builder.CreateZExt(x_mi, i128), y_hi);
auto t7 = m_builder.CreateMul(x_hi, m_builder.CreateZExt(y_lo, i128));
auto t8 = m_builder.CreateMul(x_hi, m_builder.CreateZExt(y_mi, i128));
auto p = m_builder.CreateZExt(t1, i256);
p = m_builder.CreateAdd(p, m_builder.CreateShl(m_builder.CreateZExt(t2, i256), Constant::get(64)));
p = m_builder.CreateAdd(p, m_builder.CreateShl(m_builder.CreateZExt(t3, i256), Constant::get(128)));
p = m_builder.CreateAdd(p, m_builder.CreateShl(m_builder.CreateZExt(t4, i256), Constant::get(64)));
p = m_builder.CreateAdd(p, m_builder.CreateShl(m_builder.CreateZExt(t5, i256), Constant::get(128)));
p = m_builder.CreateAdd(p, m_builder.CreateShl(m_builder.CreateZExt(t6, i256), Constant::get(192)));
p = m_builder.CreateAdd(p, m_builder.CreateShl(m_builder.CreateZExt(t7, i256), Constant::get(128)));
p = m_builder.CreateAdd(p, m_builder.CreateShl(m_builder.CreateZExt(t8, i256), Constant::get(192)));
m_builder.CreateRet(p);
}
return m_mul;
}
llvm::Function* Arith256::getDivFunc(llvm::Type* _type) llvm::Function* Arith256::getDivFunc(llvm::Type* _type)
{ {
auto& func = _type == Type::Word ? m_div : m_div512; auto& func = _type == Type::Word ? m_div : m_div512;
@ -135,7 +171,7 @@ llvm::Function* Arith256::getExpFunc()
if (!m_exp) if (!m_exp)
{ {
llvm::Type* argTypes[] = {Type::Word, Type::Word}; llvm::Type* argTypes[] = {Type::Word, Type::Word};
m_exp = llvm::Function::Create(llvm::FunctionType::get(Type::Word, argTypes, false), llvm::Function::PrivateLinkage, "arith.exp", getModule()); m_exp = llvm::Function::Create(llvm::FunctionType::get(Type::Word, argTypes, false), llvm::Function::PrivateLinkage, "exp", getModule());
auto base = &m_exp->getArgumentList().front(); auto base = &m_exp->getArgumentList().front();
base->setName("base"); base->setName("base");
@ -159,9 +195,6 @@ llvm::Function* Arith256::getExpFunc()
auto returnBB = llvm::BasicBlock::Create(m_builder.getContext(), "Return", m_exp); auto returnBB = llvm::BasicBlock::Create(m_builder.getContext(), "Return", m_exp);
m_builder.SetInsertPoint(entryBB); m_builder.SetInsertPoint(entryBB);
auto a1 = m_builder.CreateAlloca(Type::Word, nullptr, "a1");
auto a2 = m_builder.CreateAlloca(Type::Word, nullptr, "a2");
auto a3 = m_builder.CreateAlloca(Type::Word, nullptr, "a3");
m_builder.CreateBr(headerBB); m_builder.CreateBr(headerBB);
m_builder.SetInsertPoint(headerBB); m_builder.SetInsertPoint(headerBB);
@ -176,20 +209,14 @@ llvm::Function* Arith256::getExpFunc()
m_builder.CreateCondBr(eOdd, updateBB, continueBB); m_builder.CreateCondBr(eOdd, updateBB, continueBB);
m_builder.SetInsertPoint(updateBB); m_builder.SetInsertPoint(updateBB);
m_builder.CreateStore(r, a1); auto r0 = createCall(getMulFunc(), {r, b});
m_builder.CreateStore(b, a2);
createCall(m_mul, {a1, a2, a3});
auto r0 = m_builder.CreateLoad(a3, "r0");
m_builder.CreateBr(continueBB); m_builder.CreateBr(continueBB);
m_builder.SetInsertPoint(continueBB); m_builder.SetInsertPoint(continueBB);
auto r1 = m_builder.CreatePHI(Type::Word, 2, "r1"); auto r1 = m_builder.CreatePHI(Type::Word, 2, "r1");
r1->addIncoming(r, bodyBB); r1->addIncoming(r, bodyBB);
r1->addIncoming(r0, updateBB); r1->addIncoming(r0, updateBB);
m_builder.CreateStore(b, a1); auto b1 = createCall(getMulFunc(), {b, b});
m_builder.CreateStore(b, a2);
createCall(m_mul, {a1, a2, a3});
auto b1 = m_builder.CreateLoad(a3, "b1");
auto e1 = m_builder.CreateLShr(e, Constant::get(1), "e1"); auto e1 = m_builder.CreateLShr(e, Constant::get(1), "e1");
m_builder.CreateBr(headerBB); m_builder.CreateBr(headerBB);
@ -273,17 +300,9 @@ llvm::Function* Arith256::getMulModFunc()
return m_mulmod; return m_mulmod;
} }
llvm::Value* Arith256::binaryOp(llvm::Function* _op, llvm::Value* _arg1, llvm::Value* _arg2)
{
m_builder.CreateStore(_arg1, m_arg1);
m_builder.CreateStore(_arg2, m_arg2);
m_builder.CreateCall3(_op, m_arg1, m_arg2, m_result);
return m_builder.CreateLoad(m_result);
}
llvm::Value* Arith256::mul(llvm::Value* _arg1, llvm::Value* _arg2) llvm::Value* Arith256::mul(llvm::Value* _arg1, llvm::Value* _arg2)
{ {
return binaryOp(m_mul, _arg1, _arg2); return createCall(getMulFunc(), {_arg1, _arg2});
} }
std::pair<llvm::Value*, llvm::Value*> Arith256::div(llvm::Value* _arg1, llvm::Value* _arg2) std::pair<llvm::Value*, llvm::Value*> Arith256::div(llvm::Value* _arg1, llvm::Value* _arg2)
@ -496,11 +515,6 @@ extern "C"
std::cerr << "DEBUG " << z << ": " << d << c << b << a << std::endl; std::cerr << "DEBUG " << z << ": " << d << c << b << a << std::endl;
} }
EXPORT void arith_mul(uint256* _arg1, uint256* _arg2, uint256* o_result)
{
*o_result = mul(*_arg1, *_arg2);
}
EXPORT void arith_mul512(uint256* _arg1, uint256* _arg2, uint512* o_result) EXPORT void arith_mul512(uint256* _arg1, uint256* _arg2, uint512* o_result)
{ {
*o_result = mul512(*_arg1, *_arg2); *o_result = mul512(*_arg1, *_arg2);

11
libevmjit/Arith256.h

@ -24,26 +24,19 @@ public:
void debug(llvm::Value* _value, char _c); void debug(llvm::Value* _value, char _c);
private: private:
llvm::Function* getMulFunc();
llvm::Function* getDivFunc(llvm::Type* _type); llvm::Function* getDivFunc(llvm::Type* _type);
llvm::Function* getExpFunc(); llvm::Function* getExpFunc();
llvm::Function* getAddModFunc(); llvm::Function* getAddModFunc();
llvm::Function* getMulModFunc(); llvm::Function* getMulModFunc();
llvm::Value* binaryOp(llvm::Function* _op, llvm::Value* _arg1, llvm::Value* _arg2); llvm::Function* m_mul = nullptr;
llvm::Function* m_mul;
llvm::Function* m_div = nullptr; llvm::Function* m_div = nullptr;
llvm::Function* m_div512 = nullptr; llvm::Function* m_div512 = nullptr;
llvm::Function* m_exp = nullptr; llvm::Function* m_exp = nullptr;
llvm::Function* m_addmod = nullptr; llvm::Function* m_addmod = nullptr;
llvm::Function* m_mulmod = nullptr; llvm::Function* m_mulmod = nullptr;
llvm::Function* m_debug = nullptr; llvm::Function* m_debug = nullptr;
llvm::Value* m_arg1;
llvm::Value* m_arg2;
llvm::Value* m_arg3;
llvm::Value* m_result;
}; };

Loading…
Cancel
Save