diff --git a/libsolidity/CallGraph.cpp b/libsolidity/CallGraph.cpp deleted file mode 100644 index 5f8fc5470..000000000 --- a/libsolidity/CallGraph.cpp +++ /dev/null @@ -1,105 +0,0 @@ - -/* - This file is part of cpp-ethereum. - - cpp-ethereum is free software: you can redistribute it and/or modify - it under the terms of the GNU General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - - cpp-ethereum is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU General Public License for more details. - - You should have received a copy of the GNU General Public License - along with cpp-ethereum. If not, see . -*/ -/** - * @author Christian - * @date 2014 - * Callgraph of functions inside a contract. - */ - -#include -#include - -using namespace std; - -namespace dev -{ -namespace solidity -{ - -void CallGraph::addNode(ASTNode const& _node) -{ - if (!m_nodesSeen.count(&_node)) - { - m_workQueue.push(&_node); - m_nodesSeen.insert(&_node); - } -} - -set const& CallGraph::getCalls() -{ - computeCallGraph(); - return m_functionsSeen; -} - -void CallGraph::computeCallGraph() -{ - while (!m_workQueue.empty()) - { - m_workQueue.front()->accept(*this); - m_workQueue.pop(); - } -} - -bool CallGraph::visit(Identifier const& _identifier) -{ - if (auto fun = dynamic_cast(_identifier.getReferencedDeclaration())) - { - if (m_functionOverrideResolver) - fun = (*m_functionOverrideResolver)(fun->getName()); - solAssert(fun, "Error finding override for function " + fun->getName()); - addNode(*fun); - } - if (auto modifier = dynamic_cast(_identifier.getReferencedDeclaration())) - { - if (m_modifierOverrideResolver) - modifier = (*m_modifierOverrideResolver)(modifier->getName()); - solAssert(modifier, "Error finding override for modifier " + modifier->getName()); - addNode(*modifier); - } - return true; -} - -bool CallGraph::visit(FunctionDefinition const& _function) -{ - m_functionsSeen.insert(&_function); - return true; -} - -bool CallGraph::visit(MemberAccess const& _memberAccess) -{ - // used for "BaseContract.baseContractFunction" - if (_memberAccess.getExpression().getType()->getCategory() == Type::Category::TYPE) - { - TypeType const& type = dynamic_cast(*_memberAccess.getExpression().getType()); - if (type.getMembers().getMemberType(_memberAccess.getMemberName())) - { - ContractDefinition const& contract = dynamic_cast(*type.getActualType()) - .getContractDefinition(); - for (ASTPointer const& function: contract.getDefinedFunctions()) - if (function->getName() == _memberAccess.getMemberName()) - { - addNode(*function); - return true; - } - } - } - return true; -} - -} -} diff --git a/libsolidity/CallGraph.h b/libsolidity/CallGraph.h deleted file mode 100644 index 9af5cdf97..000000000 --- a/libsolidity/CallGraph.h +++ /dev/null @@ -1,69 +0,0 @@ -/* - This file is part of cpp-ethereum. - - cpp-ethereum is free software: you can redistribute it and/or modify - it under the terms of the GNU General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - - cpp-ethereum is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU General Public License for more details. - - You should have received a copy of the GNU General Public License - along with cpp-ethereum. If not, see . -*/ -/** - * @author Christian - * @date 2014 - * Callgraph of functions inside a contract. - */ - -#include -#include -#include -#include -#include - -namespace dev -{ -namespace solidity -{ - -/** - * Can be used to compute the graph of calls (or rather references) between functions of the same - * contract. Current functionality is limited to computing all functions that are directly - * or indirectly called by some functions. - */ -class CallGraph: private ASTConstVisitor -{ -public: - using FunctionOverrideResolver = std::function; - using ModifierOverrideResolver = std::function; - - CallGraph(FunctionOverrideResolver const& _functionOverrideResolver, - ModifierOverrideResolver const& _modifierOverrideResolver): - m_functionOverrideResolver(&_functionOverrideResolver), - m_modifierOverrideResolver(&_modifierOverrideResolver) {} - - void addNode(ASTNode const& _node); - - std::set const& getCalls(); - -private: - virtual bool visit(FunctionDefinition const& _function) override; - virtual bool visit(Identifier const& _identifier) override; - virtual bool visit(MemberAccess const& _memberAccess) override; - - void computeCallGraph(); - - FunctionOverrideResolver const* m_functionOverrideResolver; - ModifierOverrideResolver const* m_modifierOverrideResolver; - std::set m_nodesSeen; - std::set m_functionsSeen; - std::queue m_workQueue; -}; - -} -} diff --git a/libsolidity/Compiler.cpp b/libsolidity/Compiler.cpp index c7656363a..93784adf2 100644 --- a/libsolidity/Compiler.cpp +++ b/libsolidity/Compiler.cpp @@ -28,7 +28,6 @@ #include #include #include -#include using namespace std; @@ -40,31 +39,13 @@ void Compiler::compileContract(ContractDefinition const& _contract, { m_context = CompilerContext(); // clear it just in case initializeContext(_contract, _contracts); - - for (ContractDefinition const* contract: _contract.getLinearizedBaseContracts()) - { - for (ASTPointer const& function: contract->getDefinedFunctions()) - if (!function->isConstructor()) - m_context.addFunction(*function); - - for (ASTPointer const& vardecl: contract->getStateVariables()) - if (vardecl->isPublic()) - m_context.addFunction(*vardecl); - - for (ASTPointer const& modifier: contract->getFunctionModifiers()) - m_context.addModifier(*modifier); - } - appendFunctionSelector(_contract); - for (ContractDefinition const* contract: _contract.getLinearizedBaseContracts()) + set functions = m_context.getFunctionsWithoutCode(); + while (!functions.empty()) { - for (ASTPointer const& function: contract->getDefinedFunctions()) - if (!function->isConstructor()) - function->accept(*this); - - for (ASTPointer const& vardecl: contract->getStateVariables()) - if (vardecl->isPublic()) - generateAccessorCode(*vardecl); + for (Declaration const* function: functions) + function->accept(*this); + functions = m_context.getFunctionsWithoutCode(); } // Swap the runtime context with the creation-time context @@ -77,72 +58,26 @@ void Compiler::initializeContext(ContractDefinition const& _contract, map const& _contracts) { m_context.setCompiledContracts(_contracts); + m_context.setInheritanceHierarchy(_contract.getLinearizedBaseContracts()); registerStateVariables(_contract); } void Compiler::packIntoContractCreator(ContractDefinition const& _contract, CompilerContext const& _runtimeContext) { - std::vector const& bases = _contract.getLinearizedBaseContracts(); - - // Make all modifiers known to the context. - for (ContractDefinition const* contract: bases) - for (ASTPointer const& modifier: contract->getFunctionModifiers()) - m_context.addModifier(*modifier); - // arguments for base constructors, filled in derived-to-base order map> const*> baseArguments; - set neededFunctions; - set nodesUsedInConstructors; - // Determine the arguments that are used for the base constructors and also which functions - // are needed at compile time. + // Determine the arguments that are used for the base constructors. + std::vector const& bases = _contract.getLinearizedBaseContracts(); for (ContractDefinition const* contract: bases) - { - if (FunctionDefinition const* constructor = contract->getConstructor()) - nodesUsedInConstructors.insert(constructor); for (ASTPointer const& base: contract->getBaseContracts()) { ContractDefinition const* baseContract = dynamic_cast( base->getName()->getReferencedDeclaration()); solAssert(baseContract, ""); if (baseArguments.count(baseContract) == 0) - { baseArguments[baseContract] = &base->getArguments(); - for (ASTPointer const& arg: base->getArguments()) - nodesUsedInConstructors.insert(arg.get()); - } } - } - - auto functionOverrideResolver = [&](string const& _name) -> FunctionDefinition const* - { - for (ContractDefinition const* contract: bases) - for (ASTPointer const& function: contract->getDefinedFunctions()) - if (!function->isConstructor() && function->getName() == _name) - return function.get(); - return nullptr; - }; - auto modifierOverrideResolver = [&](string const& _name) -> ModifierDefinition const* - { - return &m_context.getFunctionModifier(_name); - }; - - neededFunctions = getFunctionsCalled(nodesUsedInConstructors, functionOverrideResolver, - modifierOverrideResolver); - - // First add all overrides (or the functions themselves if there is no override) - for (FunctionDefinition const* fun: neededFunctions) - { - FunctionDefinition const* override = nullptr; - if (!fun->isConstructor()) - override = functionOverrideResolver(fun->getName()); - if (!!override && neededFunctions.count(override)) - m_context.addFunction(*override); - } - // now add the rest - for (FunctionDefinition const* fun: neededFunctions) - if (fun->isConstructor() || functionOverrideResolver(fun->getName()) != fun) - m_context.addFunction(*fun); // Call constructors in base-to-derived order. // The Constructor for the most derived contract is called later. @@ -164,10 +99,14 @@ void Compiler::packIntoContractCreator(ContractDefinition const& _contract, Comp m_context << eth::Instruction::DUP1 << sub << u256(0) << eth::Instruction::CODECOPY; m_context << u256(0) << eth::Instruction::RETURN; - // note that we have to explicitly include all used functions because of absolute jump - // labels - for (FunctionDefinition const* fun: neededFunctions) - fun->accept(*this); + // note that we have to include the functions again because of absolute jump labels + set functions = m_context.getFunctionsWithoutCode(); + while (!functions.empty()) + { + for (Declaration const* function: functions) + function->accept(*this); + functions = m_context.getFunctionsWithoutCode(); + } } void Compiler::appendBaseConstructorCall(FunctionDefinition const& _constructor, @@ -201,16 +140,6 @@ void Compiler::appendConstructorCall(FunctionDefinition const& _constructor) m_context << returnTag; } -set Compiler::getFunctionsCalled(set const& _nodes, - function const& _resolveFunctionOverrides, - function const& _resolveModifierOverrides) -{ - CallGraph callgraph(_resolveFunctionOverrides, _resolveModifierOverrides); - for (ASTNode const* node: _nodes) - callgraph.addNode(*node); - return callgraph.getCalls(); -} - void Compiler::appendFunctionSelector(ContractDefinition const& _contract) { map, FunctionDescription> interfaceFunctions = _contract.getInterfaceFunctions(); @@ -292,19 +221,22 @@ void Compiler::registerStateVariables(ContractDefinition const& _contract) m_context.addStateVariable(*variable); } -void Compiler::generateAccessorCode(VariableDeclaration const& _varDecl) +bool Compiler::visit(VariableDeclaration const& _variableDeclaration) { - m_context.startNewFunction(); + solAssert(_variableDeclaration.isStateVariable(), "Compiler visit to non-state variable declaration."); + + m_context.startFunction(_variableDeclaration); m_breakTags.clear(); m_continueTags.clear(); - m_context << m_context.getFunctionEntryLabel(_varDecl); - ExpressionCompiler::appendStateVariableAccessor(m_context, _varDecl); + m_context << m_context.getFunctionEntryLabel(_variableDeclaration); + ExpressionCompiler::appendStateVariableAccessor(m_context, _variableDeclaration); - unsigned sizeOnStack = _varDecl.getType()->getSizeOnStack(); - solAssert(sizeOnStack <= 15, "Illegal variable stack size detected"); - m_context << eth::dupInstruction(sizeOnStack + 1); - m_context << eth::Instruction::JUMP; + unsigned sizeOnStack = _variableDeclaration.getType()->getSizeOnStack(); + solAssert(sizeOnStack <= 15, "Stack too deep."); + m_context << eth::dupInstruction(sizeOnStack + 1) << eth::Instruction::JUMP; + + return false; } bool Compiler::visit(FunctionDefinition const& _function) @@ -313,7 +245,7 @@ bool Compiler::visit(FunctionDefinition const& _function) // caller puts: [retarg0] ... [retargm] [return address] [arg0] ... [argn] // although note that this reduces the size of the visible stack - m_context.startNewFunction(); + m_context.startFunction(_function); m_returnTag = m_context.newTag(); m_breakTags.clear(); m_continueTags.clear(); @@ -321,8 +253,6 @@ bool Compiler::visit(FunctionDefinition const& _function) m_currentFunction = &_function; m_modifierDepth = 0; - m_context << m_context.getFunctionEntryLabel(_function); - // stack upon entry: [return address] [arg0] [arg1] ... [argn] // reserve additional slots: [retarg0] ... [retargm] [localvar0] ... [localvarp] diff --git a/libsolidity/Compiler.h b/libsolidity/Compiler.h index 144af8eb8..b3eae5b17 100644 --- a/libsolidity/Compiler.h +++ b/libsolidity/Compiler.h @@ -50,11 +50,6 @@ private: void appendBaseConstructorCall(FunctionDefinition const& _constructor, std::vector> const& _arguments); void appendConstructorCall(FunctionDefinition const& _constructor); - /// Recursively searches the call graph and returns all functions referenced inside _nodes. - /// _resolveOverride is called to resolve virtual function overrides. - std::set getFunctionsCalled(std::set const& _nodes, - std::function const& _resolveFunctionOverride, - std::function const& _resolveModifierOverride); void appendFunctionSelector(ContractDefinition const& _contract); /// Creates code that unpacks the arguments for the given function represented by a vector of TypePointers. /// From memory if @a _fromMemory is true, otherwise from call data. @returns the size of the data in bytes. @@ -63,8 +58,7 @@ private: void registerStateVariables(ContractDefinition const& _contract); - void generateAccessorCode(VariableDeclaration const& _varDecl); - + virtual bool visit(VariableDeclaration const& _variableDeclaration) override; virtual bool visit(FunctionDefinition const& _function) override; virtual bool visit(IfStatement const& _ifStatement) override; virtual bool visit(WhileStatement const& _whileStatement) override; diff --git a/libsolidity/CompilerContext.cpp b/libsolidity/CompilerContext.cpp index ea349c0d2..52910a556 100644 --- a/libsolidity/CompilerContext.cpp +++ b/libsolidity/CompilerContext.cpp @@ -43,6 +43,14 @@ void CompilerContext::addStateVariable(VariableDeclaration const& _declaration) m_stateVariablesSize += _declaration.getType()->getStorageSize(); } +void CompilerContext::startFunction(Declaration const& _function) +{ + m_functionsWithCode.insert(&_function); + m_localVariables.clear(); + m_asm.setDeposit(0); + *this << getFunctionEntryLabel(_function); +} + void CompilerContext::addVariable(VariableDeclaration const& _declaration, unsigned _offsetToCurrent) { @@ -59,18 +67,6 @@ void CompilerContext::addAndInitializeVariable(VariableDeclaration const& _decla *this << u256(0); } -void CompilerContext::addFunction(Declaration const& _decl) -{ - eth::AssemblyItem tag(m_asm.newTag()); - m_functionEntryLabels.insert(make_pair(&_decl, tag)); - m_virtualFunctionEntryLabels.insert(make_pair(_decl.getName(), tag)); -} - -void CompilerContext::addModifier(ModifierDefinition const& _modifier) -{ - m_functionModifiers.insert(make_pair(_modifier.getName(), &_modifier)); -} - bytes const& CompilerContext::getCompiledContract(const ContractDefinition& _contract) const { auto ret = m_compiledContracts.find(&_contract); @@ -83,25 +79,62 @@ bool CompilerContext::isLocalVariable(Declaration const* _declaration) const return m_localVariables.count(_declaration); } -eth::AssemblyItem CompilerContext::getFunctionEntryLabel(Declaration const& _declaration) const +eth::AssemblyItem CompilerContext::getFunctionEntryLabel(Declaration const& _declaration) { auto res = m_functionEntryLabels.find(&_declaration); - solAssert(res != m_functionEntryLabels.end(), "Function entry label not found."); - return res->second.tag(); + if (res == m_functionEntryLabels.end()) + { + eth::AssemblyItem tag(m_asm.newTag()); + m_functionEntryLabels.insert(make_pair(&_declaration, tag)); + return tag.tag(); + } + else + return res->second.tag(); +} + +eth::AssemblyItem CompilerContext::getVirtualFunctionEntryLabel(FunctionDefinition const& _function) +{ + solAssert(!m_inheritanceHierarchy.empty(), "No inheritance hierarchy set."); + for (ContractDefinition const* contract: m_inheritanceHierarchy) + for (ASTPointer const& function: contract->getDefinedFunctions()) + if (!function->isConstructor() && function->getName() == _function.getName()) + return getFunctionEntryLabel(*function); + solAssert(false, "Virtual function " + _function.getName() + " not found."); + return m_asm.newTag(); // not reached +} + +eth::AssemblyItem CompilerContext::getSuperFunctionEntryLabel(string const& _name, ContractDefinition const& _base) +{ + // search for first contract after _base + solAssert(!m_inheritanceHierarchy.empty(), "No inheritance hierarchy set."); + auto it = find(m_inheritanceHierarchy.begin(), m_inheritanceHierarchy.end(), &_base); + solAssert(it != m_inheritanceHierarchy.end(), "Base not found in inheritance hierarchy."); + for (++it; it != m_inheritanceHierarchy.end(); ++it) + for (ASTPointer const& function: (*it)->getDefinedFunctions()) + if (!function->isConstructor() && function->getName() == _name) + return getFunctionEntryLabel(*function); + solAssert(false, "Super function " + _name + " not found."); + return m_asm.newTag(); // not reached } -eth::AssemblyItem CompilerContext::getVirtualFunctionEntryLabel(FunctionDefinition const& _function) const +set CompilerContext::getFunctionsWithoutCode() { - auto res = m_virtualFunctionEntryLabels.find(_function.getName()); - solAssert(res != m_virtualFunctionEntryLabels.end(), "Function entry label not found."); - return res->second.tag(); + set functions; + for (auto const& it: m_functionEntryLabels) + if (m_functionsWithCode.count(it.first) == 0) + functions.insert(it.first); + return move(functions); } ModifierDefinition const& CompilerContext::getFunctionModifier(string const& _name) const { - auto res = m_functionModifiers.find(_name); - solAssert(res != m_functionModifiers.end(), "Function modifier override not found."); - return *res->second; + solAssert(!m_inheritanceHierarchy.empty(), "No inheritance hierarchy set."); + for (ContractDefinition const* contract: m_inheritanceHierarchy) + for (ASTPointer const& modifier: contract->getFunctionModifiers()) + if (modifier->getName() == _name) + return *modifier.get(); + BOOST_THROW_EXCEPTION(InternalCompilerError() + << errinfo_comment("Function modifier " + _name + " not found.")); } unsigned CompilerContext::getBaseStackOffsetOfVariable(Declaration const& _declaration) const diff --git a/libsolidity/CompilerContext.h b/libsolidity/CompilerContext.h index 9de3385a6..6d6a65b61 100644 --- a/libsolidity/CompilerContext.h +++ b/libsolidity/CompilerContext.h @@ -41,12 +41,8 @@ class CompilerContext public: void addMagicGlobal(MagicVariableDeclaration const& _declaration); void addStateVariable(VariableDeclaration const& _declaration); - void startNewFunction() { m_localVariables.clear(); m_asm.setDeposit(0); } void addVariable(VariableDeclaration const& _declaration, unsigned _offsetToCurrent = 0); void addAndInitializeVariable(VariableDeclaration const& _declaration); - void addFunction(Declaration const& _decl); - /// Adds the given modifier to the list by name if the name is not present already. - void addModifier(ModifierDefinition const& _modifier); void setCompiledContracts(std::map const& _contracts) { m_compiledContracts = _contracts; } bytes const& getCompiledContract(ContractDefinition const& _contract) const; @@ -54,13 +50,22 @@ public: void adjustStackOffset(int _adjustment) { m_asm.adjustDeposit(_adjustment); } bool isMagicGlobal(Declaration const* _declaration) const { return m_magicGlobals.count(_declaration) != 0; } - bool isFunctionDefinition(Declaration const* _declaration) const { return m_functionEntryLabels.count(_declaration) != 0; } bool isLocalVariable(Declaration const* _declaration) const; bool isStateVariable(Declaration const* _declaration) const { return m_stateVariables.count(_declaration) != 0; } - eth::AssemblyItem getFunctionEntryLabel(Declaration const& _declaration) const; + eth::AssemblyItem getFunctionEntryLabel(Declaration const& _declaration); + void setInheritanceHierarchy(std::vector const& _hierarchy) { m_inheritanceHierarchy = _hierarchy; } /// @returns the entry label of the given function and takes overrides into account. - eth::AssemblyItem getVirtualFunctionEntryLabel(FunctionDefinition const& _function) const; + eth::AssemblyItem getVirtualFunctionEntryLabel(FunctionDefinition const& _function); + /// @returns the entry label of function with the given name from the most derived class just + /// above _base in the current inheritance hierarchy. + eth::AssemblyItem getSuperFunctionEntryLabel(std::string const& _name, ContractDefinition const& _base); + /// @returns the set of functions for which we still need to generate code + std::set getFunctionsWithoutCode(); + /// Resets function specific members, inserts the function entry label and marks the function + /// as "having code". + void startFunction(Declaration const& _function); + ModifierDefinition const& getFunctionModifier(std::string const& _name) const; /// Returns the distance of the given local variable from the bottom of the stack (of the current function). unsigned getBaseStackOffsetOfVariable(Declaration const& _declaration) const; @@ -119,10 +124,10 @@ private: std::map m_localVariables; /// Labels pointing to the entry points of functions. std::map m_functionEntryLabels; - /// Labels pointing to the entry points of function overrides. - std::map m_virtualFunctionEntryLabels; - /// Mapping to obtain function modifiers by name. Should be filled from derived to base. - std::map m_functionModifiers; + /// Set of functions for which we did not yet generate code. + std::set m_functionsWithCode; + /// List of current inheritance hierarchy from derived to base. + std::vector m_inheritanceHierarchy; }; } diff --git a/libsolidity/CompilerStack.cpp b/libsolidity/CompilerStack.cpp index 0b8218bb3..3ed0d3620 100644 --- a/libsolidity/CompilerStack.cpp +++ b/libsolidity/CompilerStack.cpp @@ -94,6 +94,7 @@ void CompilerStack::parse() { m_globalContext->setCurrentContract(*contract); resolver.updateDeclaration(*m_globalContext->getCurrentThis()); + resolver.updateDeclaration(*m_globalContext->getCurrentSuper()); resolver.resolveNamesAndTypes(*contract); m_contracts[contract->getName()].contract = contract; } diff --git a/libsolidity/ExpressionCompiler.cpp b/libsolidity/ExpressionCompiler.cpp index 15ee17fd3..bcd90acfc 100644 --- a/libsolidity/ExpressionCompiler.cpp +++ b/libsolidity/ExpressionCompiler.cpp @@ -365,15 +365,25 @@ void ExpressionCompiler::endVisit(MemberAccess const& _memberAccess) { case Type::Category::CONTRACT: { + bool alsoSearchInteger = false; ContractType const& type = dynamic_cast(*_memberAccess.getExpression().getType()); - u256 identifier = type.getFunctionIdentifier(member); - if (identifier != Invalid256) + if (type.isSuper()) + m_context << m_context.getSuperFunctionEntryLabel(member, type.getContractDefinition()).pushTag(); + else { - appendTypeConversion(type, IntegerType(0, IntegerType::Modifier::ADDRESS), true); - m_context << identifier; - break; + // ordinary contract type + u256 identifier = type.getFunctionIdentifier(member); + if (identifier != Invalid256) + { + appendTypeConversion(type, IntegerType(0, IntegerType::Modifier::ADDRESS), true); + m_context << identifier; + } + else + // not found in contract, search in members inherited from address + alsoSearchInteger = true; } - // fall-through to "integer" otherwise (address) + if (!alsoSearchInteger) + break; } case Type::Category::INTEGER: if (member == "balance") @@ -469,8 +479,10 @@ void ExpressionCompiler::endVisit(Identifier const& _identifier) Declaration const* declaration = _identifier.getReferencedDeclaration(); if (MagicVariableDeclaration const* magicVar = dynamic_cast(declaration)) { - if (magicVar->getType()->getCategory() == Type::Category::CONTRACT) // must be "this" - m_context << eth::Instruction::ADDRESS; + if (magicVar->getType()->getCategory() == Type::Category::CONTRACT) + // "this" or "super" + if (!dynamic_cast(*magicVar->getType()).isSuper()) + m_context << eth::Instruction::ADDRESS; } else if (FunctionDefinition const* functionDef = dynamic_cast(declaration)) m_context << m_context.getVirtualFunctionEntryLabel(*functionDef).pushTag(); diff --git a/libsolidity/GlobalContext.cpp b/libsolidity/GlobalContext.cpp index c7eea92dc..687c9c9d4 100644 --- a/libsolidity/GlobalContext.cpp +++ b/libsolidity/GlobalContext.cpp @@ -83,5 +83,13 @@ MagicVariableDeclaration const* GlobalContext::getCurrentThis() const } +MagicVariableDeclaration const* GlobalContext::getCurrentSuper() const +{ + if (!m_superPointer[m_currentContract]) + m_superPointer[m_currentContract] = make_shared( + "super", make_shared(*m_currentContract, true)); + return m_superPointer[m_currentContract].get(); +} + } } diff --git a/libsolidity/GlobalContext.h b/libsolidity/GlobalContext.h index dfdc66623..f861c67d7 100644 --- a/libsolidity/GlobalContext.h +++ b/libsolidity/GlobalContext.h @@ -48,6 +48,7 @@ public: GlobalContext(); void setCurrentContract(ContractDefinition const& _contract); MagicVariableDeclaration const* getCurrentThis() const; + MagicVariableDeclaration const* getCurrentSuper() const; /// @returns a vector of all implicit global declarations excluding "this". std::vector getDeclarations() const; @@ -56,6 +57,7 @@ private: std::vector> m_magicVariables; ContractDefinition const* m_currentContract = nullptr; std::map> mutable m_thisPointer; + std::map> mutable m_superPointer; }; } diff --git a/libsolidity/Types.cpp b/libsolidity/Types.cpp index fcb10d4b5..3d6c4e96c 100644 --- a/libsolidity/Types.cpp +++ b/libsolidity/Types.cpp @@ -450,7 +450,9 @@ bool ContractType::isImplicitlyConvertibleTo(Type const& _convertTo) const if (_convertTo.getCategory() == Category::CONTRACT) { auto const& bases = getContractDefinition().getLinearizedBaseContracts(); - return find(bases.begin(), bases.end(), + if (m_super && bases.size() <= 1) + return false; + return find(m_super ? ++bases.begin() : bases.begin(), bases.end(), &dynamic_cast(_convertTo).getContractDefinition()) != bases.end(); } return false; @@ -472,12 +474,12 @@ bool ContractType::operator==(Type const& _other) const if (_other.getCategory() != getCategory()) return false; ContractType const& other = dynamic_cast(_other); - return other.m_contract == m_contract; + return other.m_contract == m_contract && other.m_super == m_super; } string ContractType::toString() const { - return "contract " + m_contract.getName(); + return "contract " + string(m_super ? "super " : "") + m_contract.getName(); } MemberList const& ContractType::getMembers() const @@ -488,8 +490,16 @@ MemberList const& ContractType::getMembers() const // All address members and all interface functions map> members(IntegerType::AddressMemberList.begin(), IntegerType::AddressMemberList.end()); - for (auto const& it: m_contract.getInterfaceFunctions()) - members[it.second.getName()] = it.second.getFunctionTypeShared(); + if (m_super) + { + for (ContractDefinition const* base: m_contract.getLinearizedBaseContracts()) + for (ASTPointer const& function: base->getDefinedFunctions()) + if (!function->isConstructor()) + members.insert(make_pair(function->getName(), make_shared(*function, true))); + } + else + for (auto const& it: m_contract.getInterfaceFunctions()) + members[it.second.getName()] = it.second.getFunctionTypeShared(); m_members.reset(new MemberList(members)); } return *m_members; diff --git a/libsolidity/Types.h b/libsolidity/Types.h index 2dbed95fa..3f6df13ee 100644 --- a/libsolidity/Types.h +++ b/libsolidity/Types.h @@ -277,7 +277,8 @@ class ContractType: public Type { public: virtual Category getCategory() const override { return Category::CONTRACT; } - ContractType(ContractDefinition const& _contract): m_contract(_contract) {} + explicit ContractType(ContractDefinition const& _contract, bool _super = false): + m_contract(_contract), m_super(_super) {} /// Contracts can be implicitly converted to super classes and to addresses. virtual bool isImplicitlyConvertibleTo(Type const& _convertTo) const override; /// Contracts can be converted to themselves and to integers. @@ -289,6 +290,7 @@ public: virtual MemberList const& getMembers() const override; + bool isSuper() const { return m_super; } ContractDefinition const& getContractDefinition() const { return m_contract; } /// Returns the function type of the constructor. Note that the location part of the function type @@ -301,6 +303,9 @@ public: private: ContractDefinition const& m_contract; + /// If true, it is the "super" type of the current contract, i.e. it contains only inherited + /// members. + bool m_super; /// Type of the constructor, @see getConstructorType. Lazily initialized. mutable std::shared_ptr m_constructorType; /// List of member types, will be lazy-initialized because of recursive references. @@ -314,7 +319,7 @@ class StructType: public Type { public: virtual Category getCategory() const override { return Category::STRUCT; } - StructType(StructDefinition const& _struct): m_struct(_struct) {} + explicit StructType(StructDefinition const& _struct): m_struct(_struct) {} virtual TypePointer unaryOperatorResult(Token::Value _operator) const override; virtual bool operator==(Type const& _other) const override; virtual u256 getStorageSize() const override; @@ -448,7 +453,7 @@ class TypeType: public Type { public: virtual Category getCategory() const override { return Category::TYPE; } - TypeType(TypePointer const& _actualType, ContractDefinition const* _currentContract = nullptr): + explicit TypeType(TypePointer const& _actualType, ContractDefinition const* _currentContract = nullptr): m_actualType(_actualType), m_currentContract(_currentContract) {} TypePointer const& getActualType() const { return m_actualType; } @@ -502,7 +507,7 @@ public: enum class Kind { BLOCK, MSG, TX }; virtual Category getCategory() const override { return Category::MAGIC; } - MagicType(Kind _kind); + explicit MagicType(Kind _kind); virtual TypePointer binaryOperatorResult(Token::Value, TypePointer const&) const override { diff --git a/test/SolidityCompiler.cpp b/test/SolidityCompiler.cpp index 53daa9dfe..98397af79 100644 --- a/test/SolidityCompiler.cpp +++ b/test/SolidityCompiler.cpp @@ -108,57 +108,6 @@ BOOST_AUTO_TEST_CASE(smoke_test) checkCodePresentAt(code, expectation, boilerplateSize); } -BOOST_AUTO_TEST_CASE(different_argument_numbers) -{ - char const* sourceCode = "contract test {\n" - " function f(uint a, uint b, uint c) returns(uint d) { return b; }\n" - " function g() returns (uint e, uint h) { h = f(1, 2, 3); }\n" - "}\n"; - bytes code = compileContract(sourceCode); - unsigned shift = 103; - unsigned boilerplateSize = 116; - bytes expectation({byte(Instruction::JUMPDEST), - byte(Instruction::PUSH1), 0x0, // initialize return variable d - byte(Instruction::DUP3), - byte(Instruction::SWAP1), // assign b to d - byte(Instruction::POP), - byte(Instruction::PUSH1), byte(0xa + shift), // jump to return - byte(Instruction::JUMP), - byte(Instruction::JUMPDEST), - byte(Instruction::SWAP4), // store d and fetch return address - byte(Instruction::SWAP3), // store return address - byte(Instruction::POP), - byte(Instruction::POP), - byte(Instruction::POP), - byte(Instruction::JUMP), // end of f - byte(Instruction::JUMPDEST), // beginning of g - byte(Instruction::PUSH1), 0x0, - byte(Instruction::PUSH1), 0x0, // initialized e and h - byte(Instruction::PUSH1), byte(0x21 + shift), // ret address - byte(Instruction::PUSH1), 0x1, - byte(Instruction::PUSH1), 0x2, - byte(Instruction::PUSH1), 0x3, - byte(Instruction::PUSH1), byte(0x1 + shift), - // stack here: ret e h 0x20 1 2 3 0x1 - byte(Instruction::JUMP), - byte(Instruction::JUMPDEST), - // stack here: ret e h f(1,2,3) - byte(Instruction::SWAP1), - // stack here: ret e f(1,2,3) h - byte(Instruction::POP), - byte(Instruction::DUP1), // retrieve it again as "value of expression" - byte(Instruction::POP), // end of assignment - // stack here: ret e f(1,2,3) - byte(Instruction::JUMPDEST), - byte(Instruction::SWAP1), - // ret e f(1,2,3) - byte(Instruction::SWAP2), - // f(1,2,3) e ret - byte(Instruction::JUMP) // end of g - }); - checkCodePresentAt(code, expectation, boilerplateSize); -} - BOOST_AUTO_TEST_CASE(ifStatement) { char const* sourceCode = "contract test {\n" diff --git a/test/SolidityEndToEndTest.cpp b/test/SolidityEndToEndTest.cpp index 1ddb22731..1450095af 100644 --- a/test/SolidityEndToEndTest.cpp +++ b/test/SolidityEndToEndTest.cpp @@ -1930,6 +1930,30 @@ BOOST_AUTO_TEST_CASE(crazy_elementary_typenames_on_stack) BOOST_CHECK(callContractFunction("f()") == encodeArgs(u256(-7))); } +BOOST_AUTO_TEST_CASE(super) +{ + char const* sourceCode = R"( + contract A { function f() returns (uint r) { return 1; } } + contract B is A { function f() returns (uint r) { return super.f() | 2; } } + contract C is A { function f() returns (uint r) { return super.f() | 4; } } + contract D is B, C { function f() returns (uint r) { return super.f() | 8; } } + )"; + compileAndRun(sourceCode, 0, "D"); + BOOST_CHECK(callContractFunction("f()") == encodeArgs(1 | 2 | 4 | 8)); +} + +BOOST_AUTO_TEST_CASE(super_in_constructor) +{ + char const* sourceCode = R"( + contract A { function f() returns (uint r) { return 1; } } + contract B is A { function f() returns (uint r) { return super.f() | 2; } } + contract C is A { function f() returns (uint r) { return super.f() | 4; } } + contract D is B, C { uint data; function D() { data = super.f() | 8; } function f() returns (uint r) { return data; } } + )"; + compileAndRun(sourceCode, 0, "D"); + BOOST_CHECK(callContractFunction("f()") == encodeArgs(1 | 2 | 4 | 8)); +} + BOOST_AUTO_TEST_SUITE_END() } diff --git a/test/SolidityExpressionCompiler.cpp b/test/SolidityExpressionCompiler.cpp index 06c252db3..a0cca3a3a 100644 --- a/test/SolidityExpressionCompiler.cpp +++ b/test/SolidityExpressionCompiler.cpp @@ -86,7 +86,8 @@ Declaration const& resolveDeclaration(vector const& _namespacedName, } bytes compileFirstExpression(const string& _sourceCode, vector> _functions = {}, - vector> _localVariables = {}, vector> _globalDeclarations = {}) + vector> _localVariables = {}, + vector> _globalDeclarations = {}) { Parser parser; ASTPointer sourceUnit; @@ -99,10 +100,12 @@ bytes compileFirstExpression(const string& _sourceCode, vector> _ NameAndTypeResolver resolver(declarations); resolver.registerDeclarations(*sourceUnit); + vector inheritanceHierarchy; for (ASTPointer const& node: sourceUnit->getNodes()) if (ContractDefinition* contract = dynamic_cast(node.get())) { BOOST_REQUIRE_NO_THROW(resolver.resolveNamesAndTypes(*contract)); + inheritanceHierarchy = vector(1, contract); } for (ASTPointer const& node: sourceUnit->getNodes()) if (ContractDefinition* contract = dynamic_cast(node.get())) @@ -116,8 +119,7 @@ bytes compileFirstExpression(const string& _sourceCode, vector> _ BOOST_REQUIRE(extractor.getExpression() != nullptr); CompilerContext context; - for (vector const& function: _functions) - context.addFunction(dynamic_cast(resolveDeclaration(function, resolver))); + context.setInheritanceHierarchy(inheritanceHierarchy); unsigned parametersSize = _localVariables.size(); // assume they are all one slot on the stack context.adjustStackOffset(parametersSize); for (vector const& variable: _localVariables)