diff --git a/libevmasm/CommonSubexpressionEliminator.cpp b/libevmasm/CommonSubexpressionEliminator.cpp index e369c9dbc..7564fcd99 100644 --- a/libevmasm/CommonSubexpressionEliminator.cpp +++ b/libevmasm/CommonSubexpressionEliminator.cpp @@ -153,7 +153,9 @@ AssemblyItems CSECodeGenerator::generateCode( assertThrow(!m_classPositions[targetItem.second].empty(), OptimizerException, ""); if (m_classPositions[targetItem.second].count(targetItem.first)) continue; - SourceLocation const& location = m_expressionClasses.representative(targetItem.second).item->getLocation(); + SourceLocation location; + if (m_expressionClasses.representative(targetItem.second).item) + location = m_expressionClasses.representative(targetItem.second).item->getLocation(); int position = classElementPosition(targetItem.second); if (position < targetItem.first) // it is already at its target, we need another copy @@ -197,7 +199,9 @@ void CSECodeGenerator::addDependencies(Id _c) addDependencies(argument); m_neededBy.insert(make_pair(argument, _c)); } - if (expr.item->type() == Operation && ( + if ( + expr.item && + expr.item->type() == Operation && ( expr.item->instruction() == Instruction::SLOAD || expr.item->instruction() == Instruction::MLOAD || expr.item->instruction() == Instruction::SHA3 @@ -288,6 +292,7 @@ void CSECodeGenerator::generateClassElement(Id _c, bool _allowSequenced) OptimizerException, "Sequence constrained operation requested out of sequence." ); + assertThrow(expr.item, OptimizerException, "Non-generated expression without item."); vector const& arguments = expr.arguments; for (Id arg: boost::adaptors::reverse(arguments)) generateClassElement(arg); diff --git a/libevmasm/ControlFlowGraph.cpp b/libevmasm/ControlFlowGraph.cpp index cc68b2af8..3566bdb17 100644 --- a/libevmasm/ControlFlowGraph.cpp +++ b/libevmasm/ControlFlowGraph.cpp @@ -24,6 +24,7 @@ #include #include #include +#include #include #include #include @@ -217,7 +218,6 @@ void ControlFlowGraph::gatherKnowledge() // @todo actually we know that memory is filled with zeros at the beginning, // we could make use of that. KnownStatePointer emptyState = make_shared(); - ExpressionClasses& expr = emptyState->expressionClasses(); bool unknownJumpEncountered = false; vector> workQueue({make_pair(BlockId::initial(), emptyState->copy())}); @@ -238,8 +238,6 @@ void ControlFlowGraph::gatherKnowledge() } block.startState = state->copy(); - //@todo we might know the return address for the first pass, but not anymore for the second, - // -> store knowledge about tags as a union. // Feed all items except for the final jump yet because it will erase the target tag. unsigned pc = block.begin; @@ -254,22 +252,29 @@ void ControlFlowGraph::gatherKnowledge() assertThrow(block.begin <= pc && pc == block.end - 1, OptimizerException, ""); //@todo in the case of JUMPI, add knowledge about the condition to the state // (for both values of the condition) - BlockId nextBlock = expressionClassToBlockId( - state->stackElement(state->stackHeight(), SourceLocation()), - expr + set tags = state->tagsInExpression( + state->stackElement(state->stackHeight(), SourceLocation()) ); state->feedItem(m_items.at(pc++)); - if (nextBlock) - workQueue.push_back(make_pair(nextBlock, state->copy())); - else if (!unknownJumpEncountered) + + if (tags.empty() || std::any_of(tags.begin(), tags.end(), [&](u256 const& _tag) + { + return !m_blocks.count(BlockId(_tag)); + })) { - // We do not know where this jump goes, so we have to reset the states of all - // JUMPDESTs. - unknownJumpEncountered = true; - for (auto const& it: m_blocks) - if (it.second.begin < it.second.end && m_items[it.second.begin].type() == Tag) - workQueue.push_back(make_pair(it.first, emptyState->copy())); + if (!unknownJumpEncountered) + { + // We do not know the target of this jump, so we have to reset the states of all + // JUMPDESTs. + unknownJumpEncountered = true; + for (auto const& it: m_blocks) + if (it.second.begin < it.second.end && m_items[it.second.begin].type() == Tag) + workQueue.push_back(make_pair(it.first, emptyState->copy())); + } } + else + for (auto tag: tags) + workQueue.push_back(make_pair(BlockId(tag), state->copy())); } else if (block.begin <= pc && pc < block.end) state->feedItem(m_items.at(pc++)); @@ -329,7 +334,11 @@ BasicBlocks ControlFlowGraph::rebuildCode() if (previousHandedOver && !pushes[blockId] && m_items[block.begin].type() == Tag) ++block.begin; if (block.begin < block.end) + { blocks.push_back(block); + blocks.back().startState->clearTagUnions(); + blocks.back().endState->clearTagUnions(); + } previousHandedOver = (block.endType == BasicBlock::EndType::HANDOVER); } } @@ -337,18 +346,6 @@ BasicBlocks ControlFlowGraph::rebuildCode() return blocks; } -BlockId ControlFlowGraph::expressionClassToBlockId( - ExpressionClasses::Id _id, - ExpressionClasses& _exprClasses -) -{ - ExpressionClasses::Expression expr = _exprClasses.representative(_id); - if (expr.item && expr.item->type() == PushTag) - return BlockId(expr.item->data()); - else - return BlockId::invalid(); -} - BlockId ControlFlowGraph::generateNewId() { BlockId id = BlockId(++m_lastUsedId); diff --git a/libevmasm/ControlFlowGraph.h b/libevmasm/ControlFlowGraph.h index 3366dc45f..4480ba491 100644 --- a/libevmasm/ControlFlowGraph.h +++ b/libevmasm/ControlFlowGraph.h @@ -108,10 +108,6 @@ private: void setPrevLinks(); BasicBlocks rebuildCode(); - /// @returns the corresponding BlockId if _id is a pushed jump tag, - /// and an invalid BlockId otherwise. - BlockId expressionClassToBlockId(ExpressionClasses::Id _id, ExpressionClasses& _exprClasses); - BlockId generateNewId(); unsigned m_lastUsedId = 0; diff --git a/libevmasm/ExpressionClasses.cpp b/libevmasm/ExpressionClasses.cpp index cfbeba7fa..81adc0dbb 100644 --- a/libevmasm/ExpressionClasses.cpp +++ b/libevmasm/ExpressionClasses.cpp @@ -82,6 +82,16 @@ ExpressionClasses::Id ExpressionClasses::find( return exp.id; } +ExpressionClasses::Id ExpressionClasses::newClass(SourceLocation const& _location) +{ + Expression exp; + exp.id = m_representatives.size(); + exp.item = storeItem(AssemblyItem(UndefinedItem, (u256(1) << 255) + exp.id, _location)); + m_representatives.push_back(exp); + m_expressions.insert(exp); + return exp.id; +} + bool ExpressionClasses::knownToBeDifferent(ExpressionClasses::Id _a, ExpressionClasses::Id _b) { // Try to simplify "_a - _b" and return true iff the value is a non-zero constant. diff --git a/libevmasm/ExpressionClasses.h b/libevmasm/ExpressionClasses.h index c83520300..dd94092e8 100644 --- a/libevmasm/ExpressionClasses.h +++ b/libevmasm/ExpressionClasses.h @@ -52,7 +52,8 @@ public: Id id; AssemblyItem const* item = nullptr; Ids arguments; - unsigned sequenceNumber; ///< Storage modification sequence, only used for SLOAD/SSTORE instructions. + /// Storage modification sequence, only used for storage and memory operations. + unsigned sequenceNumber = 0; /// Behaves as if this was a tuple of (item->type(), item->data(), arguments, sequenceNumber). bool operator<(Expression const& _other) const; }; @@ -73,6 +74,9 @@ public: /// @returns the number of classes. Id size() const { return m_representatives.size(); } + /// @returns the id of a new class which is different to all other classes. + Id newClass(SourceLocation const& _location); + /// @returns true if the values of the given classes are known to be different (on every input). /// @note that this function might still return false for some different inputs. bool knownToBeDifferent(Id _a, Id _b); diff --git a/libevmasm/KnownState.cpp b/libevmasm/KnownState.cpp index 5a70a74fb..b84e656aa 100644 --- a/libevmasm/KnownState.cpp +++ b/libevmasm/KnownState.cpp @@ -162,29 +162,41 @@ KnownState::StoreOperation KnownState::feedItem(AssemblyItem const& _item, bool /// Helper function for KnownState::reduceToCommonKnowledge, removes everything from /// _this which is not in or not equal to the value in _other. -template void intersect( - _Mapping& _this, - _Mapping const& _other, - function<_KeyType(_KeyType)> const& _keyTrans = [](_KeyType _k) { return _k; } -) +template void intersect(_Mapping& _this, _Mapping const& _other) { for (auto it = _this.begin(); it != _this.end();) - if (_other.count(_keyTrans(it->first)) && _other.at(_keyTrans(it->first)) == it->second) + if (_other.count(it->first) && _other.at(it->first) == it->second) ++it; else it = _this.erase(it); } -template void intersect(_Mapping& _this, _Mapping const& _other) -{ - intersect<_Mapping, ExpressionClasses::Id>(_this, _other, [](ExpressionClasses::Id _k) { return _k; }); -} - void KnownState::reduceToCommonKnowledge(KnownState const& _other) { int stackDiff = m_stackHeight - _other.m_stackHeight; - function stackKeyTransform = [=](int _key) -> int { return _key - stackDiff; }; - intersect(m_stackElements, _other.m_stackElements, stackKeyTransform); + for (auto it = m_stackElements.begin(); it != m_stackElements.end();) + if (_other.m_stackElements.count(it->first - stackDiff)) + { + Id other = _other.m_stackElements.at(it->first - stackDiff); + if (it->second == other) + ++it; + else + { + set theseTags = tagsInExpression(it->second); + set otherTags = tagsInExpression(other); + if (!theseTags.empty() && !otherTags.empty()) + { + theseTags.insert(otherTags.begin(), otherTags.end()); + it->second = tagUnion(theseTags); + ++it; + } + else + it = m_stackElements.erase(it); + } + } + else + it = m_stackElements.erase(it); + // Use the smaller stack height. Essential to terminate in case of loops. if (m_stackHeight > _other.m_stackHeight) { @@ -201,10 +213,15 @@ void KnownState::reduceToCommonKnowledge(KnownState const& _other) bool KnownState::operator==(const KnownState& _other) const { - return m_storageContent == _other.m_storageContent && - m_memoryContent == _other.m_memoryContent && - m_stackHeight == _other.m_stackHeight && - m_stackElements == _other.m_stackElements; + if (m_storageContent != _other.m_storageContent || m_memoryContent != _other.m_memoryContent) + return false; + int stackDiff = m_stackHeight - _other.m_stackHeight; + auto thisIt = m_stackElements.cbegin(); + auto otherIt = _other.m_stackElements.cbegin(); + for (; thisIt != m_stackElements.cend() && otherIt != _other.m_stackElements.cend(); ++thisIt, ++otherIt) + if (thisIt->first - stackDiff != otherIt->first || thisIt->second != otherIt->second) + return false; + return (thisIt == m_stackElements.cend() && otherIt == _other.m_stackElements.cend()); } ExpressionClasses::Id KnownState::stackElement(int _stackHeight, SourceLocation const& _location) @@ -212,18 +229,17 @@ ExpressionClasses::Id KnownState::stackElement(int _stackHeight, SourceLocation if (m_stackElements.count(_stackHeight)) return m_stackElements.at(_stackHeight); // Stack element not found (not assigned yet), create new unknown equivalence class. - //@todo check that we do not infer incorrect equivalences when the stack is cleared partially - //in between. - return m_stackElements[_stackHeight] = initialStackElement(_stackHeight, _location); + return m_stackElements[_stackHeight] = + m_expressionClasses->find(AssemblyItem(UndefinedItem, _stackHeight, _location)); } -ExpressionClasses::Id KnownState::initialStackElement( - int _stackHeight, - SourceLocation const& _location -) +void KnownState::clearTagUnions() { - // This is a special assembly item that refers to elements pre-existing on the initial stack. - return m_expressionClasses->find(AssemblyItem(UndefinedItem, u256(_stackHeight), _location)); + for (auto it = m_stackElements.begin(); it != m_stackElements.end();) + if (m_tagUnions.left.count(it->second)) + it = m_stackElements.erase(it); + else + ++it; } void KnownState::setStackElement(int _stackHeight, Id _class) @@ -352,3 +368,27 @@ KnownState::Id KnownState::applySha3( return m_knownSha3Hashes[arguments] = v; } +set KnownState::tagsInExpression(KnownState::Id _expressionId) +{ + if (m_tagUnions.left.count(_expressionId)) + return m_tagUnions.left.at(_expressionId); + // Might be a tag, then return the set of itself. + ExpressionClasses::Expression expr = m_expressionClasses->representative(_expressionId); + if (expr.item && expr.item->type() == PushTag) + return set({expr.item->data()}); + else + return set(); +} + +KnownState::Id KnownState::tagUnion(set _tags) +{ + if (m_tagUnions.right.count(_tags)) + return m_tagUnions.right.at(_tags); + else + { + Id id = m_expressionClasses->newClass(SourceLocation()); + m_tagUnions.right.insert(make_pair(_tags, id)); + return id; + } +} + diff --git a/libevmasm/KnownState.h b/libevmasm/KnownState.h index f7a3dd675..3505df74f 100644 --- a/libevmasm/KnownState.h +++ b/libevmasm/KnownState.h @@ -29,6 +29,7 @@ #include #include #include +#include #include #include #include @@ -107,15 +108,16 @@ public: /// @returns true if the knowledge about the state of both objects is (known to be) equal. bool operator==(KnownState const& _other) const; - ///@todo the sequence numbers in two copies of this class should never be the same. - /// might be doable using two-dimensional sequence numbers, where the first value is incremented - /// for each copy - /// Retrieves the current equivalence class fo the given stack element (or generates a new /// one if it does not exist yet). Id stackElement(int _stackHeight, SourceLocation const& _location); - /// @returns the equivalence class id of the special initial stack element at the given height. - Id initialStackElement(int _stackHeight, SourceLocation const& _location); + + /// @returns its set of tags if the given expression class is a known tag union; returns a set + /// containing the tag if it is a PushTag expression and the empty set otherwise. + std::set tagsInExpression(Id _expressionId); + /// During analysis, different tags on the stack are partially treated as the same class. + /// This removes such classes not to confuse later analyzers. + void clearTagUnions(); int stackHeight() const { return m_stackHeight; } std::map const& stackElements() const { return m_stackElements; } @@ -142,6 +144,9 @@ private: /// Finds or creates a new expression that applies the sha3 hash function to the contents in memory. Id applySha3(Id _start, Id _length, SourceLocation const& _location); + /// @returns a new or already used Id representing the given set of tags. + Id tagUnion(std::set _tags); + /// Current stack height, can be negative. int m_stackHeight = 0; /// Current stack layout, mapping stack height -> equivalence class @@ -157,6 +162,8 @@ private: std::map, Id> m_knownSha3Hashes; /// Structure containing the classes of equivalent expressions. std::shared_ptr m_expressionClasses; + /// Container for unions of tags stored on the stack. + boost::bimap> m_tagUnions; }; } diff --git a/test/libsolidity/SolidityOptimizer.cpp b/test/libsolidity/SolidityOptimizer.cpp index efc9316b0..ce43887e1 100644 --- a/test/libsolidity/SolidityOptimizer.cpp +++ b/test/libsolidity/SolidityOptimizer.cpp @@ -315,6 +315,49 @@ BOOST_AUTO_TEST_CASE(retain_information_in_branches) BOOST_CHECK_EQUAL(1, numSHA3s); } +BOOST_AUTO_TEST_CASE(store_tags_as_unions) +{ + // This calls the same function from two sources and both calls have a certain sha3 on + // the stack at the same position. + // Without storing tags as unions, the return from the shared function would not know where to + // jump and thus all jumpdests are forced to clear their state and we do not know about the + // sha3 anymore. + // Note that, for now, this only works if the functions have the same number of return + // parameters since otherwise, the return jump addresses are at different stack positions + // which triggers the "unknown jump target" situation. + char const* sourceCode = R"( + contract test { + bytes32 data; + function f(uint x, bytes32 y) external returns (uint r_a, bytes32 r_d) { + r_d = sha3(y); + shared(y); + r_d = sha3(y); + r_a = 5; + } + function g(uint x, bytes32 y) external returns (uint r_a, bytes32 r_d) { + r_d = sha3(y); + shared(y); + r_d = bytes32(uint(sha3(y)) + 2); + r_a = 7; + } + function shared(bytes32 y) internal { + data = sha3(y); + } + } + )"; + compileBothVersions(sourceCode); + compareVersions("f()", 7, "abc"); + + m_optimize = true; + bytes optimizedBytecode = compileAndRun(sourceCode, 0, "test"); + size_t numSHA3s = 0; + eth::eachInstruction(optimizedBytecode, [&](Instruction _instr, u256 const&) { + if (_instr == eth::Instruction::SHA3) + numSHA3s++; + }); + BOOST_CHECK_EQUAL(2, numSHA3s); +} + BOOST_AUTO_TEST_CASE(cse_intermediate_swap) { eth::KnownState state;