diff --git a/libevmjit/BasicBlock.cpp b/libevmjit/BasicBlock.cpp index 93d73b991..a41743d0b 100644 --- a/libevmjit/BasicBlock.cpp +++ b/libevmjit/BasicBlock.cpp @@ -49,6 +49,7 @@ void BasicBlock::LocalStack::push(llvm::Value* _value) assert(_value->getType() == Type::Word); m_bblock.m_currentStack.push_back(_value); m_bblock.m_tosOffset += 1; + m_maxSize = std::max(m_maxSize, m_bblock.m_currentStack.size()); } llvm::Value* BasicBlock::LocalStack::pop() diff --git a/libevmjit/BasicBlock.h b/libevmjit/BasicBlock.h index 7469b7b69..5e19235a7 100644 --- a/libevmjit/BasicBlock.h +++ b/libevmjit/BasicBlock.h @@ -33,6 +33,9 @@ public: /// @param _index Index of value to be swaped. Must be > 0. void swap(size_t _index); + size_t getMaxSize() const { return m_maxSize; } + int getDiff() const { return m_bblock.m_tosOffset; } + private: LocalStack(BasicBlock& _owner); LocalStack(LocalStack const&) = delete; @@ -49,6 +52,7 @@ public: private: BasicBlock& m_bblock; + size_t m_maxSize = 0; ///< Max size reached by the stack. }; explicit BasicBlock(instr_idx _firstInstrIdx, code_iterator _begin, code_iterator _end, llvm::Function* _mainFunc, llvm::IRBuilder<>& _builder, bool isJumpDest); diff --git a/libevmjit/Compiler.cpp b/libevmjit/Compiler.cpp index be3670874..1e1f8fe93 100644 --- a/libevmjit/Compiler.cpp +++ b/libevmjit/Compiler.cpp @@ -838,6 +838,9 @@ void Compiler::compileBasicBlock(BasicBlock& _basicBlock, RuntimeManager& _runti // Block may have no terminator if the next instruction is a jump destination. if (!_basicBlock.llvm()->getTerminator()) m_builder.CreateBr(_nextBasicBlock); + + m_builder.SetInsertPoint(_basicBlock.llvm()->getFirstNonPHI()); + _runtimeManager.checkStackLimit(_basicBlock.localStack().getMaxSize(), _basicBlock.localStack().getDiff()); } diff --git a/libevmjit/RuntimeManager.cpp b/libevmjit/RuntimeManager.cpp index d1ccaea8a..f4db4d052 100644 --- a/libevmjit/RuntimeManager.cpp +++ b/libevmjit/RuntimeManager.cpp @@ -78,7 +78,7 @@ llvm::Twine getName(RuntimeData::Index _index) case RuntimeData::CodeSize: return "code"; case RuntimeData::CallDataSize: return "callDataSize"; case RuntimeData::Gas: return "gas"; - case RuntimeData::Number: return "number"; + case RuntimeData::Number: return "number"; case RuntimeData::Timestamp: return "timestamp"; } } @@ -101,6 +101,48 @@ RuntimeManager::RuntimeManager(llvm::IRBuilder<>& _builder, code_iterator _codeB assert(m_memPtr->getType() == Array::getType()->getPointerTo()); m_envPtr = m_builder.CreateLoad(m_builder.CreateStructGEP(rtPtr, 1), "env"); assert(m_envPtr->getType() == Type::EnvPtr); + + m_stackSize = m_builder.CreateAlloca(Type::Size, nullptr, "stackSize"); + m_builder.CreateStore(m_builder.getInt64(0), m_stackSize); + + llvm::Type* checkStackLimitArgs[] = {Type::Size->getPointerTo(), Type::Size, Type::Size, Type::BytePtr}; + m_checkStackLimit = llvm::Function::Create(llvm::FunctionType::get(Type::Void, checkStackLimitArgs, false), llvm::Function::PrivateLinkage, "stack.checkSize", getModule()); + m_checkStackLimit->setDoesNotThrow(); + m_checkStackLimit->setDoesNotCapture(1); + + auto checkBB = llvm::BasicBlock::Create(_builder.getContext(), "Check", m_checkStackLimit); + auto updateBB = llvm::BasicBlock::Create(_builder.getContext(), "Update", m_checkStackLimit); + auto outOfStackBB = llvm::BasicBlock::Create(_builder.getContext(), "OutOfStack", m_checkStackLimit); + + auto currSizePtr = &m_checkStackLimit->getArgumentList().front(); + currSizePtr->setName("currSize"); + auto max = currSizePtr->getNextNode(); + max->setName("max"); + auto diff = max->getNextNode(); + diff->setName("diff"); + auto jmpBuf = diff->getNextNode(); + jmpBuf->setName("jmpBuf"); + + InsertPointGuard guard{m_builder}; + m_builder.SetInsertPoint(checkBB); + auto currSize = m_builder.CreateLoad(currSizePtr, "cur"); + auto maxSize = m_builder.CreateNUWAdd(currSize, max, "maxSize"); + auto ok = m_builder.CreateICmpULT(maxSize, m_builder.getInt64(1024), "ok"); + m_builder.CreateCondBr(ok, updateBB, outOfStackBB, Type::expectTrue); + + m_builder.SetInsertPoint(updateBB); + auto newSize = m_builder.CreateNSWAdd(currSize, diff); + m_builder.CreateStore(newSize, currSizePtr); + m_builder.CreateRetVoid(); + + m_builder.SetInsertPoint(outOfStackBB); + abort(jmpBuf); + m_builder.CreateUnreachable(); +} + +void RuntimeManager::checkStackLimit(size_t _max, int _diff) +{ + createCall(m_checkStackLimit, {m_stackSize, m_builder.getInt64(_max), m_builder.getInt64(_diff), getJmpBuf()}); } llvm::Value* RuntimeManager::getRuntimePtr() diff --git a/libevmjit/RuntimeManager.h b/libevmjit/RuntimeManager.h index eb8dadf6f..c430a1098 100644 --- a/libevmjit/RuntimeManager.h +++ b/libevmjit/RuntimeManager.h @@ -48,6 +48,8 @@ public: static llvm::StructType* getRuntimeType(); static llvm::StructType* getRuntimeDataType(); + void checkStackLimit(size_t _max, int _diff); + private: llvm::Value* getPtr(RuntimeData::Index _index); void set(RuntimeData::Index _index, llvm::Value* _value); @@ -59,6 +61,9 @@ private: llvm::Value* m_memPtr = nullptr; llvm::Value* m_envPtr = nullptr; + llvm::Value* m_stackSize = nullptr; + llvm::Function* m_checkStackLimit = nullptr; + code_iterator m_codeBegin = {}; code_iterator m_codeEnd = {};