From 6b7787cd2bc39d396d7f6ab2036256ba6968bad7 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Pawe=C5=82=20Bylica?= <pawel.bylica@imapp.pl>
Date: Tue, 3 Mar 2015 15:26:16 +0100
Subject: [PATCH] Ad-hoc constant fold arithmetic instructions

---
 libevmjit/Arith256.cpp | 90 +++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 88 insertions(+), 2 deletions(-)

diff --git a/libevmjit/Arith256.cpp b/libevmjit/Arith256.cpp
index 220aa6f05..47701462c 100644
--- a/libevmjit/Arith256.cpp
+++ b/libevmjit/Arith256.cpp
@@ -408,18 +408,49 @@ llvm::Function* Arith256::getMulModFunc()
 
 llvm::Value* Arith256::mul(llvm::Value* _arg1, llvm::Value* _arg2)
 {
+	if (auto c1 = llvm::dyn_cast<llvm::ConstantInt>(_arg1))
+	{
+		if (auto c2 = llvm::dyn_cast<llvm::ConstantInt>(_arg2))
+			return Constant::get(c1->getValue() * c2->getValue());
+	}
+
 	return createCall(getMulFunc(), {_arg1, _arg2});
 }
 
 std::pair<llvm::Value*, llvm::Value*> Arith256::div(llvm::Value* _arg1, llvm::Value* _arg2)
 {
-	auto div =  m_builder.CreateExtractValue(createCall(getDivFunc(Type::Word), {_arg1, _arg2}), 0, "div");
-	auto mod =  m_builder.CreateExtractValue(createCall(getDivFunc(Type::Word), {_arg1, _arg2}), 1, "mod");
+	if (auto c1 = llvm::dyn_cast<llvm::ConstantInt>(_arg1))
+	{
+		if (auto c2 = llvm::dyn_cast<llvm::ConstantInt>(_arg2))
+		{
+			if (!c2->getValue())
+				return std::make_pair(Constant::get(0), Constant::get(0));
+			auto div = Constant::get(c1->getValue().udiv(c2->getValue()));
+			auto mod = Constant::get(c1->getValue().urem(c2->getValue()));
+			return std::make_pair(div, mod);
+		}
+	}
+
+	auto r = createCall(getDivFunc(Type::Word), {_arg1, _arg2});
+	auto div =  m_builder.CreateExtractValue(r, 0, "div");
+	auto mod =  m_builder.CreateExtractValue(r, 1, "mod");
 	return std::make_pair(div, mod);
 }
 
 std::pair<llvm::Value*, llvm::Value*> Arith256::sdiv(llvm::Value* _x, llvm::Value* _y)
 {
+	if (auto c1 = llvm::dyn_cast<llvm::ConstantInt>(_x))
+	{
+		if (auto c2 = llvm::dyn_cast<llvm::ConstantInt>(_y))
+		{
+			if (!c2->getValue())
+				return std::make_pair(Constant::get(0), Constant::get(0));
+			auto div = Constant::get(c1->getValue().sdiv(c2->getValue()));
+			auto mod = Constant::get(c1->getValue().srem(c2->getValue()));
+			return std::make_pair(div, mod);
+		}
+	}
+
 	auto xIsNeg = m_builder.CreateICmpSLT(_x, Constant::get(0));
 	auto xNeg = m_builder.CreateSub(Constant::get(0), _x);
 	auto xAbs = m_builder.CreateSelect(xIsNeg, xNeg, _x);
@@ -443,16 +474,71 @@ std::pair<llvm::Value*, llvm::Value*> Arith256::sdiv(llvm::Value* _x, llvm::Valu
 
 llvm::Value* Arith256::exp(llvm::Value* _arg1, llvm::Value* _arg2)
 {
+	//	while (e != 0) {
+	//		if (e % 2 == 1)
+	//			r *= b;
+	//		b *= b;
+	//		e /= 2;
+	//	}
+
+	if (auto c1 = llvm::dyn_cast<llvm::ConstantInt>(_arg1))
+	{
+		if (auto c2 = llvm::dyn_cast<llvm::ConstantInt>(_arg2))
+		{
+			auto b = c1->getValue();
+			auto e = c2->getValue();
+			auto r = llvm::APInt{256, 1};
+			while (e != 0)
+			{
+				if (e[0])
+					r *= b;
+				b *= b;
+				e = e.lshr(1);
+			}
+			return Constant::get(r);
+		}
+	}
+
 	return createCall(getExpFunc(), {_arg1, _arg2});
 }
 
 llvm::Value* Arith256::addmod(llvm::Value* _arg1, llvm::Value* _arg2, llvm::Value* _arg3)
 {
+	if (auto c1 = llvm::dyn_cast<llvm::ConstantInt>(_arg1))
+	{
+		if (auto c2 = llvm::dyn_cast<llvm::ConstantInt>(_arg2))
+		{
+			if (auto c3 = llvm::dyn_cast<llvm::ConstantInt>(_arg3))
+			{
+				if (!c3->getValue())
+					return Constant::get(0);
+				auto s = c1->getValue().zext(256+64) + c2->getValue().zext(256+64);
+				auto r = s.urem(c3->getValue().zext(256+64)).trunc(256);
+				return Constant::get(r);
+			}
+		}
+	}
+
 	return createCall(getAddModFunc(), {_arg1, _arg2, _arg3});
 }
 
 llvm::Value* Arith256::mulmod(llvm::Value* _arg1, llvm::Value* _arg2, llvm::Value* _arg3)
 {
+	if (auto c1 = llvm::dyn_cast<llvm::ConstantInt>(_arg1))
+	{
+		if (auto c2 = llvm::dyn_cast<llvm::ConstantInt>(_arg2))
+		{
+			if (auto c3 = llvm::dyn_cast<llvm::ConstantInt>(_arg3))
+			{
+				if (!c3->getValue())
+					return Constant::get(0);
+				auto p = c1->getValue().zext(512) * c2->getValue().zext(512);
+				auto r = p.urem(c3->getValue().zext(512)).trunc(256);
+				return Constant::get(r);
+			}
+		}
+	}
+
 	return createCall(getMulModFunc(), {_arg1, _arg2, _arg3});
 }