diff --git a/libsolidity/CallGraph.cpp b/libsolidity/CallGraph.cpp
deleted file mode 100644
index 5f8fc5470..000000000
--- a/libsolidity/CallGraph.cpp
+++ /dev/null
@@ -1,105 +0,0 @@
-
-/*
- This file is part of cpp-ethereum.
-
- cpp-ethereum is free software: you can redistribute it and/or modify
- it under the terms of the GNU General Public License as published by
- the Free Software Foundation, either version 3 of the License, or
- (at your option) any later version.
-
- cpp-ethereum is distributed in the hope that it will be useful,
- but WITHOUT ANY WARRANTY; without even the implied warranty of
- MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
- GNU General Public License for more details.
-
- You should have received a copy of the GNU General Public License
- along with cpp-ethereum. If not, see .
-*/
-/**
- * @author Christian
- * @date 2014
- * Callgraph of functions inside a contract.
- */
-
-#include
-#include
-
-using namespace std;
-
-namespace dev
-{
-namespace solidity
-{
-
-void CallGraph::addNode(ASTNode const& _node)
-{
- if (!m_nodesSeen.count(&_node))
- {
- m_workQueue.push(&_node);
- m_nodesSeen.insert(&_node);
- }
-}
-
-set const& CallGraph::getCalls()
-{
- computeCallGraph();
- return m_functionsSeen;
-}
-
-void CallGraph::computeCallGraph()
-{
- while (!m_workQueue.empty())
- {
- m_workQueue.front()->accept(*this);
- m_workQueue.pop();
- }
-}
-
-bool CallGraph::visit(Identifier const& _identifier)
-{
- if (auto fun = dynamic_cast(_identifier.getReferencedDeclaration()))
- {
- if (m_functionOverrideResolver)
- fun = (*m_functionOverrideResolver)(fun->getName());
- solAssert(fun, "Error finding override for function " + fun->getName());
- addNode(*fun);
- }
- if (auto modifier = dynamic_cast(_identifier.getReferencedDeclaration()))
- {
- if (m_modifierOverrideResolver)
- modifier = (*m_modifierOverrideResolver)(modifier->getName());
- solAssert(modifier, "Error finding override for modifier " + modifier->getName());
- addNode(*modifier);
- }
- return true;
-}
-
-bool CallGraph::visit(FunctionDefinition const& _function)
-{
- m_functionsSeen.insert(&_function);
- return true;
-}
-
-bool CallGraph::visit(MemberAccess const& _memberAccess)
-{
- // used for "BaseContract.baseContractFunction"
- if (_memberAccess.getExpression().getType()->getCategory() == Type::Category::TYPE)
- {
- TypeType const& type = dynamic_cast(*_memberAccess.getExpression().getType());
- if (type.getMembers().getMemberType(_memberAccess.getMemberName()))
- {
- ContractDefinition const& contract = dynamic_cast(*type.getActualType())
- .getContractDefinition();
- for (ASTPointer const& function: contract.getDefinedFunctions())
- if (function->getName() == _memberAccess.getMemberName())
- {
- addNode(*function);
- return true;
- }
- }
- }
- return true;
-}
-
-}
-}
diff --git a/libsolidity/CallGraph.h b/libsolidity/CallGraph.h
deleted file mode 100644
index 9af5cdf97..000000000
--- a/libsolidity/CallGraph.h
+++ /dev/null
@@ -1,69 +0,0 @@
-/*
- This file is part of cpp-ethereum.
-
- cpp-ethereum is free software: you can redistribute it and/or modify
- it under the terms of the GNU General Public License as published by
- the Free Software Foundation, either version 3 of the License, or
- (at your option) any later version.
-
- cpp-ethereum is distributed in the hope that it will be useful,
- but WITHOUT ANY WARRANTY; without even the implied warranty of
- MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
- GNU General Public License for more details.
-
- You should have received a copy of the GNU General Public License
- along with cpp-ethereum. If not, see .
-*/
-/**
- * @author Christian
- * @date 2014
- * Callgraph of functions inside a contract.
- */
-
-#include
-#include
-#include
-#include
-#include
-
-namespace dev
-{
-namespace solidity
-{
-
-/**
- * Can be used to compute the graph of calls (or rather references) between functions of the same
- * contract. Current functionality is limited to computing all functions that are directly
- * or indirectly called by some functions.
- */
-class CallGraph: private ASTConstVisitor
-{
-public:
- using FunctionOverrideResolver = std::function;
- using ModifierOverrideResolver = std::function;
-
- CallGraph(FunctionOverrideResolver const& _functionOverrideResolver,
- ModifierOverrideResolver const& _modifierOverrideResolver):
- m_functionOverrideResolver(&_functionOverrideResolver),
- m_modifierOverrideResolver(&_modifierOverrideResolver) {}
-
- void addNode(ASTNode const& _node);
-
- std::set const& getCalls();
-
-private:
- virtual bool visit(FunctionDefinition const& _function) override;
- virtual bool visit(Identifier const& _identifier) override;
- virtual bool visit(MemberAccess const& _memberAccess) override;
-
- void computeCallGraph();
-
- FunctionOverrideResolver const* m_functionOverrideResolver;
- ModifierOverrideResolver const* m_modifierOverrideResolver;
- std::set m_nodesSeen;
- std::set m_functionsSeen;
- std::queue m_workQueue;
-};
-
-}
-}
diff --git a/libsolidity/Compiler.cpp b/libsolidity/Compiler.cpp
index c7656363a..93784adf2 100644
--- a/libsolidity/Compiler.cpp
+++ b/libsolidity/Compiler.cpp
@@ -28,7 +28,6 @@
#include
#include
#include
-#include
using namespace std;
@@ -40,31 +39,13 @@ void Compiler::compileContract(ContractDefinition const& _contract,
{
m_context = CompilerContext(); // clear it just in case
initializeContext(_contract, _contracts);
-
- for (ContractDefinition const* contract: _contract.getLinearizedBaseContracts())
- {
- for (ASTPointer const& function: contract->getDefinedFunctions())
- if (!function->isConstructor())
- m_context.addFunction(*function);
-
- for (ASTPointer const& vardecl: contract->getStateVariables())
- if (vardecl->isPublic())
- m_context.addFunction(*vardecl);
-
- for (ASTPointer const& modifier: contract->getFunctionModifiers())
- m_context.addModifier(*modifier);
- }
-
appendFunctionSelector(_contract);
- for (ContractDefinition const* contract: _contract.getLinearizedBaseContracts())
+ set functions = m_context.getFunctionsWithoutCode();
+ while (!functions.empty())
{
- for (ASTPointer const& function: contract->getDefinedFunctions())
- if (!function->isConstructor())
- function->accept(*this);
-
- for (ASTPointer const& vardecl: contract->getStateVariables())
- if (vardecl->isPublic())
- generateAccessorCode(*vardecl);
+ for (Declaration const* function: functions)
+ function->accept(*this);
+ functions = m_context.getFunctionsWithoutCode();
}
// Swap the runtime context with the creation-time context
@@ -77,72 +58,26 @@ void Compiler::initializeContext(ContractDefinition const& _contract,
map const& _contracts)
{
m_context.setCompiledContracts(_contracts);
+ m_context.setInheritanceHierarchy(_contract.getLinearizedBaseContracts());
registerStateVariables(_contract);
}
void Compiler::packIntoContractCreator(ContractDefinition const& _contract, CompilerContext const& _runtimeContext)
{
- std::vector const& bases = _contract.getLinearizedBaseContracts();
-
- // Make all modifiers known to the context.
- for (ContractDefinition const* contract: bases)
- for (ASTPointer const& modifier: contract->getFunctionModifiers())
- m_context.addModifier(*modifier);
-
// arguments for base constructors, filled in derived-to-base order
map> const*> baseArguments;
- set neededFunctions;
- set nodesUsedInConstructors;
- // Determine the arguments that are used for the base constructors and also which functions
- // are needed at compile time.
+ // Determine the arguments that are used for the base constructors.
+ std::vector const& bases = _contract.getLinearizedBaseContracts();
for (ContractDefinition const* contract: bases)
- {
- if (FunctionDefinition const* constructor = contract->getConstructor())
- nodesUsedInConstructors.insert(constructor);
for (ASTPointer const& base: contract->getBaseContracts())
{
ContractDefinition const* baseContract = dynamic_cast(
base->getName()->getReferencedDeclaration());
solAssert(baseContract, "");
if (baseArguments.count(baseContract) == 0)
- {
baseArguments[baseContract] = &base->getArguments();
- for (ASTPointer const& arg: base->getArguments())
- nodesUsedInConstructors.insert(arg.get());
- }
}
- }
-
- auto functionOverrideResolver = [&](string const& _name) -> FunctionDefinition const*
- {
- for (ContractDefinition const* contract: bases)
- for (ASTPointer const& function: contract->getDefinedFunctions())
- if (!function->isConstructor() && function->getName() == _name)
- return function.get();
- return nullptr;
- };
- auto modifierOverrideResolver = [&](string const& _name) -> ModifierDefinition const*
- {
- return &m_context.getFunctionModifier(_name);
- };
-
- neededFunctions = getFunctionsCalled(nodesUsedInConstructors, functionOverrideResolver,
- modifierOverrideResolver);
-
- // First add all overrides (or the functions themselves if there is no override)
- for (FunctionDefinition const* fun: neededFunctions)
- {
- FunctionDefinition const* override = nullptr;
- if (!fun->isConstructor())
- override = functionOverrideResolver(fun->getName());
- if (!!override && neededFunctions.count(override))
- m_context.addFunction(*override);
- }
- // now add the rest
- for (FunctionDefinition const* fun: neededFunctions)
- if (fun->isConstructor() || functionOverrideResolver(fun->getName()) != fun)
- m_context.addFunction(*fun);
// Call constructors in base-to-derived order.
// The Constructor for the most derived contract is called later.
@@ -164,10 +99,14 @@ void Compiler::packIntoContractCreator(ContractDefinition const& _contract, Comp
m_context << eth::Instruction::DUP1 << sub << u256(0) << eth::Instruction::CODECOPY;
m_context << u256(0) << eth::Instruction::RETURN;
- // note that we have to explicitly include all used functions because of absolute jump
- // labels
- for (FunctionDefinition const* fun: neededFunctions)
- fun->accept(*this);
+ // note that we have to include the functions again because of absolute jump labels
+ set functions = m_context.getFunctionsWithoutCode();
+ while (!functions.empty())
+ {
+ for (Declaration const* function: functions)
+ function->accept(*this);
+ functions = m_context.getFunctionsWithoutCode();
+ }
}
void Compiler::appendBaseConstructorCall(FunctionDefinition const& _constructor,
@@ -201,16 +140,6 @@ void Compiler::appendConstructorCall(FunctionDefinition const& _constructor)
m_context << returnTag;
}
-set Compiler::getFunctionsCalled(set const& _nodes,
- function const& _resolveFunctionOverrides,
- function const& _resolveModifierOverrides)
-{
- CallGraph callgraph(_resolveFunctionOverrides, _resolveModifierOverrides);
- for (ASTNode const* node: _nodes)
- callgraph.addNode(*node);
- return callgraph.getCalls();
-}
-
void Compiler::appendFunctionSelector(ContractDefinition const& _contract)
{
map, FunctionDescription> interfaceFunctions = _contract.getInterfaceFunctions();
@@ -292,19 +221,22 @@ void Compiler::registerStateVariables(ContractDefinition const& _contract)
m_context.addStateVariable(*variable);
}
-void Compiler::generateAccessorCode(VariableDeclaration const& _varDecl)
+bool Compiler::visit(VariableDeclaration const& _variableDeclaration)
{
- m_context.startNewFunction();
+ solAssert(_variableDeclaration.isStateVariable(), "Compiler visit to non-state variable declaration.");
+
+ m_context.startFunction(_variableDeclaration);
m_breakTags.clear();
m_continueTags.clear();
- m_context << m_context.getFunctionEntryLabel(_varDecl);
- ExpressionCompiler::appendStateVariableAccessor(m_context, _varDecl);
+ m_context << m_context.getFunctionEntryLabel(_variableDeclaration);
+ ExpressionCompiler::appendStateVariableAccessor(m_context, _variableDeclaration);
- unsigned sizeOnStack = _varDecl.getType()->getSizeOnStack();
- solAssert(sizeOnStack <= 15, "Illegal variable stack size detected");
- m_context << eth::dupInstruction(sizeOnStack + 1);
- m_context << eth::Instruction::JUMP;
+ unsigned sizeOnStack = _variableDeclaration.getType()->getSizeOnStack();
+ solAssert(sizeOnStack <= 15, "Stack too deep.");
+ m_context << eth::dupInstruction(sizeOnStack + 1) << eth::Instruction::JUMP;
+
+ return false;
}
bool Compiler::visit(FunctionDefinition const& _function)
@@ -313,7 +245,7 @@ bool Compiler::visit(FunctionDefinition const& _function)
// caller puts: [retarg0] ... [retargm] [return address] [arg0] ... [argn]
// although note that this reduces the size of the visible stack
- m_context.startNewFunction();
+ m_context.startFunction(_function);
m_returnTag = m_context.newTag();
m_breakTags.clear();
m_continueTags.clear();
@@ -321,8 +253,6 @@ bool Compiler::visit(FunctionDefinition const& _function)
m_currentFunction = &_function;
m_modifierDepth = 0;
- m_context << m_context.getFunctionEntryLabel(_function);
-
// stack upon entry: [return address] [arg0] [arg1] ... [argn]
// reserve additional slots: [retarg0] ... [retargm] [localvar0] ... [localvarp]
diff --git a/libsolidity/Compiler.h b/libsolidity/Compiler.h
index 144af8eb8..b3eae5b17 100644
--- a/libsolidity/Compiler.h
+++ b/libsolidity/Compiler.h
@@ -50,11 +50,6 @@ private:
void appendBaseConstructorCall(FunctionDefinition const& _constructor,
std::vector> const& _arguments);
void appendConstructorCall(FunctionDefinition const& _constructor);
- /// Recursively searches the call graph and returns all functions referenced inside _nodes.
- /// _resolveOverride is called to resolve virtual function overrides.
- std::set getFunctionsCalled(std::set const& _nodes,
- std::function const& _resolveFunctionOverride,
- std::function const& _resolveModifierOverride);
void appendFunctionSelector(ContractDefinition const& _contract);
/// Creates code that unpacks the arguments for the given function represented by a vector of TypePointers.
/// From memory if @a _fromMemory is true, otherwise from call data. @returns the size of the data in bytes.
@@ -63,8 +58,7 @@ private:
void registerStateVariables(ContractDefinition const& _contract);
- void generateAccessorCode(VariableDeclaration const& _varDecl);
-
+ virtual bool visit(VariableDeclaration const& _variableDeclaration) override;
virtual bool visit(FunctionDefinition const& _function) override;
virtual bool visit(IfStatement const& _ifStatement) override;
virtual bool visit(WhileStatement const& _whileStatement) override;
diff --git a/libsolidity/CompilerContext.cpp b/libsolidity/CompilerContext.cpp
index ea349c0d2..52910a556 100644
--- a/libsolidity/CompilerContext.cpp
+++ b/libsolidity/CompilerContext.cpp
@@ -43,6 +43,14 @@ void CompilerContext::addStateVariable(VariableDeclaration const& _declaration)
m_stateVariablesSize += _declaration.getType()->getStorageSize();
}
+void CompilerContext::startFunction(Declaration const& _function)
+{
+ m_functionsWithCode.insert(&_function);
+ m_localVariables.clear();
+ m_asm.setDeposit(0);
+ *this << getFunctionEntryLabel(_function);
+}
+
void CompilerContext::addVariable(VariableDeclaration const& _declaration,
unsigned _offsetToCurrent)
{
@@ -59,18 +67,6 @@ void CompilerContext::addAndInitializeVariable(VariableDeclaration const& _decla
*this << u256(0);
}
-void CompilerContext::addFunction(Declaration const& _decl)
-{
- eth::AssemblyItem tag(m_asm.newTag());
- m_functionEntryLabels.insert(make_pair(&_decl, tag));
- m_virtualFunctionEntryLabels.insert(make_pair(_decl.getName(), tag));
-}
-
-void CompilerContext::addModifier(ModifierDefinition const& _modifier)
-{
- m_functionModifiers.insert(make_pair(_modifier.getName(), &_modifier));
-}
-
bytes const& CompilerContext::getCompiledContract(const ContractDefinition& _contract) const
{
auto ret = m_compiledContracts.find(&_contract);
@@ -83,25 +79,62 @@ bool CompilerContext::isLocalVariable(Declaration const* _declaration) const
return m_localVariables.count(_declaration);
}
-eth::AssemblyItem CompilerContext::getFunctionEntryLabel(Declaration const& _declaration) const
+eth::AssemblyItem CompilerContext::getFunctionEntryLabel(Declaration const& _declaration)
{
auto res = m_functionEntryLabels.find(&_declaration);
- solAssert(res != m_functionEntryLabels.end(), "Function entry label not found.");
- return res->second.tag();
+ if (res == m_functionEntryLabels.end())
+ {
+ eth::AssemblyItem tag(m_asm.newTag());
+ m_functionEntryLabels.insert(make_pair(&_declaration, tag));
+ return tag.tag();
+ }
+ else
+ return res->second.tag();
+}
+
+eth::AssemblyItem CompilerContext::getVirtualFunctionEntryLabel(FunctionDefinition const& _function)
+{
+ solAssert(!m_inheritanceHierarchy.empty(), "No inheritance hierarchy set.");
+ for (ContractDefinition const* contract: m_inheritanceHierarchy)
+ for (ASTPointer const& function: contract->getDefinedFunctions())
+ if (!function->isConstructor() && function->getName() == _function.getName())
+ return getFunctionEntryLabel(*function);
+ solAssert(false, "Virtual function " + _function.getName() + " not found.");
+ return m_asm.newTag(); // not reached
+}
+
+eth::AssemblyItem CompilerContext::getSuperFunctionEntryLabel(string const& _name, ContractDefinition const& _base)
+{
+ // search for first contract after _base
+ solAssert(!m_inheritanceHierarchy.empty(), "No inheritance hierarchy set.");
+ auto it = find(m_inheritanceHierarchy.begin(), m_inheritanceHierarchy.end(), &_base);
+ solAssert(it != m_inheritanceHierarchy.end(), "Base not found in inheritance hierarchy.");
+ for (++it; it != m_inheritanceHierarchy.end(); ++it)
+ for (ASTPointer const& function: (*it)->getDefinedFunctions())
+ if (!function->isConstructor() && function->getName() == _name)
+ return getFunctionEntryLabel(*function);
+ solAssert(false, "Super function " + _name + " not found.");
+ return m_asm.newTag(); // not reached
}
-eth::AssemblyItem CompilerContext::getVirtualFunctionEntryLabel(FunctionDefinition const& _function) const
+set CompilerContext::getFunctionsWithoutCode()
{
- auto res = m_virtualFunctionEntryLabels.find(_function.getName());
- solAssert(res != m_virtualFunctionEntryLabels.end(), "Function entry label not found.");
- return res->second.tag();
+ set functions;
+ for (auto const& it: m_functionEntryLabels)
+ if (m_functionsWithCode.count(it.first) == 0)
+ functions.insert(it.first);
+ return move(functions);
}
ModifierDefinition const& CompilerContext::getFunctionModifier(string const& _name) const
{
- auto res = m_functionModifiers.find(_name);
- solAssert(res != m_functionModifiers.end(), "Function modifier override not found.");
- return *res->second;
+ solAssert(!m_inheritanceHierarchy.empty(), "No inheritance hierarchy set.");
+ for (ContractDefinition const* contract: m_inheritanceHierarchy)
+ for (ASTPointer const& modifier: contract->getFunctionModifiers())
+ if (modifier->getName() == _name)
+ return *modifier.get();
+ BOOST_THROW_EXCEPTION(InternalCompilerError()
+ << errinfo_comment("Function modifier " + _name + " not found."));
}
unsigned CompilerContext::getBaseStackOffsetOfVariable(Declaration const& _declaration) const
diff --git a/libsolidity/CompilerContext.h b/libsolidity/CompilerContext.h
index 9de3385a6..6d6a65b61 100644
--- a/libsolidity/CompilerContext.h
+++ b/libsolidity/CompilerContext.h
@@ -41,12 +41,8 @@ class CompilerContext
public:
void addMagicGlobal(MagicVariableDeclaration const& _declaration);
void addStateVariable(VariableDeclaration const& _declaration);
- void startNewFunction() { m_localVariables.clear(); m_asm.setDeposit(0); }
void addVariable(VariableDeclaration const& _declaration, unsigned _offsetToCurrent = 0);
void addAndInitializeVariable(VariableDeclaration const& _declaration);
- void addFunction(Declaration const& _decl);
- /// Adds the given modifier to the list by name if the name is not present already.
- void addModifier(ModifierDefinition const& _modifier);
void setCompiledContracts(std::map const& _contracts) { m_compiledContracts = _contracts; }
bytes const& getCompiledContract(ContractDefinition const& _contract) const;
@@ -54,13 +50,22 @@ public:
void adjustStackOffset(int _adjustment) { m_asm.adjustDeposit(_adjustment); }
bool isMagicGlobal(Declaration const* _declaration) const { return m_magicGlobals.count(_declaration) != 0; }
- bool isFunctionDefinition(Declaration const* _declaration) const { return m_functionEntryLabels.count(_declaration) != 0; }
bool isLocalVariable(Declaration const* _declaration) const;
bool isStateVariable(Declaration const* _declaration) const { return m_stateVariables.count(_declaration) != 0; }
- eth::AssemblyItem getFunctionEntryLabel(Declaration const& _declaration) const;
+ eth::AssemblyItem getFunctionEntryLabel(Declaration const& _declaration);
+ void setInheritanceHierarchy(std::vector const& _hierarchy) { m_inheritanceHierarchy = _hierarchy; }
/// @returns the entry label of the given function and takes overrides into account.
- eth::AssemblyItem getVirtualFunctionEntryLabel(FunctionDefinition const& _function) const;
+ eth::AssemblyItem getVirtualFunctionEntryLabel(FunctionDefinition const& _function);
+ /// @returns the entry label of function with the given name from the most derived class just
+ /// above _base in the current inheritance hierarchy.
+ eth::AssemblyItem getSuperFunctionEntryLabel(std::string const& _name, ContractDefinition const& _base);
+ /// @returns the set of functions for which we still need to generate code
+ std::set getFunctionsWithoutCode();
+ /// Resets function specific members, inserts the function entry label and marks the function
+ /// as "having code".
+ void startFunction(Declaration const& _function);
+
ModifierDefinition const& getFunctionModifier(std::string const& _name) const;
/// Returns the distance of the given local variable from the bottom of the stack (of the current function).
unsigned getBaseStackOffsetOfVariable(Declaration const& _declaration) const;
@@ -119,10 +124,10 @@ private:
std::map m_localVariables;
/// Labels pointing to the entry points of functions.
std::map m_functionEntryLabels;
- /// Labels pointing to the entry points of function overrides.
- std::map m_virtualFunctionEntryLabels;
- /// Mapping to obtain function modifiers by name. Should be filled from derived to base.
- std::map m_functionModifiers;
+ /// Set of functions for which we did not yet generate code.
+ std::set m_functionsWithCode;
+ /// List of current inheritance hierarchy from derived to base.
+ std::vector m_inheritanceHierarchy;
};
}
diff --git a/libsolidity/CompilerStack.cpp b/libsolidity/CompilerStack.cpp
index 0b8218bb3..3ed0d3620 100644
--- a/libsolidity/CompilerStack.cpp
+++ b/libsolidity/CompilerStack.cpp
@@ -94,6 +94,7 @@ void CompilerStack::parse()
{
m_globalContext->setCurrentContract(*contract);
resolver.updateDeclaration(*m_globalContext->getCurrentThis());
+ resolver.updateDeclaration(*m_globalContext->getCurrentSuper());
resolver.resolveNamesAndTypes(*contract);
m_contracts[contract->getName()].contract = contract;
}
diff --git a/libsolidity/ExpressionCompiler.cpp b/libsolidity/ExpressionCompiler.cpp
index 15ee17fd3..d2f709be0 100644
--- a/libsolidity/ExpressionCompiler.cpp
+++ b/libsolidity/ExpressionCompiler.cpp
@@ -366,14 +366,22 @@ void ExpressionCompiler::endVisit(MemberAccess const& _memberAccess)
case Type::Category::CONTRACT:
{
ContractType const& type = dynamic_cast(*_memberAccess.getExpression().getType());
- u256 identifier = type.getFunctionIdentifier(member);
- if (identifier != Invalid256)
+ if (type.isSuper())
{
- appendTypeConversion(type, IntegerType(0, IntegerType::Modifier::ADDRESS), true);
- m_context << identifier;
+ m_context << m_context.getSuperFunctionEntryLabel(member, type.getContractDefinition()).pushTag();
break;
}
- // fall-through to "integer" otherwise (address)
+ else
+ {
+ u256 identifier = type.getFunctionIdentifier(member);
+ if (identifier != Invalid256)
+ {
+ appendTypeConversion(type, IntegerType(0, IntegerType::Modifier::ADDRESS), true);
+ m_context << identifier;
+ break;
+ }
+ // fall-through to "integer" otherwise (address)
+ }
}
case Type::Category::INTEGER:
if (member == "balance")
@@ -469,8 +477,10 @@ void ExpressionCompiler::endVisit(Identifier const& _identifier)
Declaration const* declaration = _identifier.getReferencedDeclaration();
if (MagicVariableDeclaration const* magicVar = dynamic_cast(declaration))
{
- if (magicVar->getType()->getCategory() == Type::Category::CONTRACT) // must be "this"
- m_context << eth::Instruction::ADDRESS;
+ if (magicVar->getType()->getCategory() == Type::Category::CONTRACT)
+ // "this" or "super"
+ if (!dynamic_cast(*magicVar->getType()).isSuper())
+ m_context << eth::Instruction::ADDRESS;
}
else if (FunctionDefinition const* functionDef = dynamic_cast(declaration))
m_context << m_context.getVirtualFunctionEntryLabel(*functionDef).pushTag();
diff --git a/libsolidity/GlobalContext.cpp b/libsolidity/GlobalContext.cpp
index c7eea92dc..40a498c8e 100644
--- a/libsolidity/GlobalContext.cpp
+++ b/libsolidity/GlobalContext.cpp
@@ -83,5 +83,14 @@ MagicVariableDeclaration const* GlobalContext::getCurrentThis() const
}
+MagicVariableDeclaration const* GlobalContext::getCurrentSuper() const
+{
+ if (!m_superPointer[m_currentContract])
+ m_superPointer[m_currentContract] = make_shared(
+ "super", make_shared(*m_currentContract, true));
+ return m_superPointer[m_currentContract].get();
+
+}
+
}
}
diff --git a/libsolidity/GlobalContext.h b/libsolidity/GlobalContext.h
index dfdc66623..f861c67d7 100644
--- a/libsolidity/GlobalContext.h
+++ b/libsolidity/GlobalContext.h
@@ -48,6 +48,7 @@ public:
GlobalContext();
void setCurrentContract(ContractDefinition const& _contract);
MagicVariableDeclaration const* getCurrentThis() const;
+ MagicVariableDeclaration const* getCurrentSuper() const;
/// @returns a vector of all implicit global declarations excluding "this".
std::vector getDeclarations() const;
@@ -56,6 +57,7 @@ private:
std::vector> m_magicVariables;
ContractDefinition const* m_currentContract = nullptr;
std::map> mutable m_thisPointer;
+ std::map> mutable m_superPointer;
};
}
diff --git a/libsolidity/Types.cpp b/libsolidity/Types.cpp
index fcb10d4b5..3d6c4e96c 100644
--- a/libsolidity/Types.cpp
+++ b/libsolidity/Types.cpp
@@ -450,7 +450,9 @@ bool ContractType::isImplicitlyConvertibleTo(Type const& _convertTo) const
if (_convertTo.getCategory() == Category::CONTRACT)
{
auto const& bases = getContractDefinition().getLinearizedBaseContracts();
- return find(bases.begin(), bases.end(),
+ if (m_super && bases.size() <= 1)
+ return false;
+ return find(m_super ? ++bases.begin() : bases.begin(), bases.end(),
&dynamic_cast(_convertTo).getContractDefinition()) != bases.end();
}
return false;
@@ -472,12 +474,12 @@ bool ContractType::operator==(Type const& _other) const
if (_other.getCategory() != getCategory())
return false;
ContractType const& other = dynamic_cast(_other);
- return other.m_contract == m_contract;
+ return other.m_contract == m_contract && other.m_super == m_super;
}
string ContractType::toString() const
{
- return "contract " + m_contract.getName();
+ return "contract " + string(m_super ? "super " : "") + m_contract.getName();
}
MemberList const& ContractType::getMembers() const
@@ -488,8 +490,16 @@ MemberList const& ContractType::getMembers() const
// All address members and all interface functions
map> members(IntegerType::AddressMemberList.begin(),
IntegerType::AddressMemberList.end());
- for (auto const& it: m_contract.getInterfaceFunctions())
- members[it.second.getName()] = it.second.getFunctionTypeShared();
+ if (m_super)
+ {
+ for (ContractDefinition const* base: m_contract.getLinearizedBaseContracts())
+ for (ASTPointer const& function: base->getDefinedFunctions())
+ if (!function->isConstructor())
+ members.insert(make_pair(function->getName(), make_shared(*function, true)));
+ }
+ else
+ for (auto const& it: m_contract.getInterfaceFunctions())
+ members[it.second.getName()] = it.second.getFunctionTypeShared();
m_members.reset(new MemberList(members));
}
return *m_members;
diff --git a/libsolidity/Types.h b/libsolidity/Types.h
index 2dbed95fa..3f6df13ee 100644
--- a/libsolidity/Types.h
+++ b/libsolidity/Types.h
@@ -277,7 +277,8 @@ class ContractType: public Type
{
public:
virtual Category getCategory() const override { return Category::CONTRACT; }
- ContractType(ContractDefinition const& _contract): m_contract(_contract) {}
+ explicit ContractType(ContractDefinition const& _contract, bool _super = false):
+ m_contract(_contract), m_super(_super) {}
/// Contracts can be implicitly converted to super classes and to addresses.
virtual bool isImplicitlyConvertibleTo(Type const& _convertTo) const override;
/// Contracts can be converted to themselves and to integers.
@@ -289,6 +290,7 @@ public:
virtual MemberList const& getMembers() const override;
+ bool isSuper() const { return m_super; }
ContractDefinition const& getContractDefinition() const { return m_contract; }
/// Returns the function type of the constructor. Note that the location part of the function type
@@ -301,6 +303,9 @@ public:
private:
ContractDefinition const& m_contract;
+ /// If true, it is the "super" type of the current contract, i.e. it contains only inherited
+ /// members.
+ bool m_super;
/// Type of the constructor, @see getConstructorType. Lazily initialized.
mutable std::shared_ptr m_constructorType;
/// List of member types, will be lazy-initialized because of recursive references.
@@ -314,7 +319,7 @@ class StructType: public Type
{
public:
virtual Category getCategory() const override { return Category::STRUCT; }
- StructType(StructDefinition const& _struct): m_struct(_struct) {}
+ explicit StructType(StructDefinition const& _struct): m_struct(_struct) {}
virtual TypePointer unaryOperatorResult(Token::Value _operator) const override;
virtual bool operator==(Type const& _other) const override;
virtual u256 getStorageSize() const override;
@@ -448,7 +453,7 @@ class TypeType: public Type
{
public:
virtual Category getCategory() const override { return Category::TYPE; }
- TypeType(TypePointer const& _actualType, ContractDefinition const* _currentContract = nullptr):
+ explicit TypeType(TypePointer const& _actualType, ContractDefinition const* _currentContract = nullptr):
m_actualType(_actualType), m_currentContract(_currentContract) {}
TypePointer const& getActualType() const { return m_actualType; }
@@ -502,7 +507,7 @@ public:
enum class Kind { BLOCK, MSG, TX };
virtual Category getCategory() const override { return Category::MAGIC; }
- MagicType(Kind _kind);
+ explicit MagicType(Kind _kind);
virtual TypePointer binaryOperatorResult(Token::Value, TypePointer const&) const override
{
diff --git a/test/SolidityCompiler.cpp b/test/SolidityCompiler.cpp
index 53daa9dfe..e7e64805a 100644
--- a/test/SolidityCompiler.cpp
+++ b/test/SolidityCompiler.cpp
@@ -111,8 +111,8 @@ BOOST_AUTO_TEST_CASE(smoke_test)
BOOST_AUTO_TEST_CASE(different_argument_numbers)
{
char const* sourceCode = "contract test {\n"
- " function f(uint a, uint b, uint c) returns(uint d) { return b; }\n"
" function g() returns (uint e, uint h) { h = f(1, 2, 3); }\n"
+ " function f(uint a, uint b, uint c) returns(uint d) { return b; }\n"
"}\n";
bytes code = compileContract(sourceCode);
unsigned shift = 103;
diff --git a/test/SolidityEndToEndTest.cpp b/test/SolidityEndToEndTest.cpp
index 1ddb22731..1450095af 100644
--- a/test/SolidityEndToEndTest.cpp
+++ b/test/SolidityEndToEndTest.cpp
@@ -1930,6 +1930,30 @@ BOOST_AUTO_TEST_CASE(crazy_elementary_typenames_on_stack)
BOOST_CHECK(callContractFunction("f()") == encodeArgs(u256(-7)));
}
+BOOST_AUTO_TEST_CASE(super)
+{
+ char const* sourceCode = R"(
+ contract A { function f() returns (uint r) { return 1; } }
+ contract B is A { function f() returns (uint r) { return super.f() | 2; } }
+ contract C is A { function f() returns (uint r) { return super.f() | 4; } }
+ contract D is B, C { function f() returns (uint r) { return super.f() | 8; } }
+ )";
+ compileAndRun(sourceCode, 0, "D");
+ BOOST_CHECK(callContractFunction("f()") == encodeArgs(1 | 2 | 4 | 8));
+}
+
+BOOST_AUTO_TEST_CASE(super_in_constructor)
+{
+ char const* sourceCode = R"(
+ contract A { function f() returns (uint r) { return 1; } }
+ contract B is A { function f() returns (uint r) { return super.f() | 2; } }
+ contract C is A { function f() returns (uint r) { return super.f() | 4; } }
+ contract D is B, C { uint data; function D() { data = super.f() | 8; } function f() returns (uint r) { return data; } }
+ )";
+ compileAndRun(sourceCode, 0, "D");
+ BOOST_CHECK(callContractFunction("f()") == encodeArgs(1 | 2 | 4 | 8));
+}
+
BOOST_AUTO_TEST_SUITE_END()
}
diff --git a/test/SolidityExpressionCompiler.cpp b/test/SolidityExpressionCompiler.cpp
index 06c252db3..a0cca3a3a 100644
--- a/test/SolidityExpressionCompiler.cpp
+++ b/test/SolidityExpressionCompiler.cpp
@@ -86,7 +86,8 @@ Declaration const& resolveDeclaration(vector const& _namespacedName,
}
bytes compileFirstExpression(const string& _sourceCode, vector> _functions = {},
- vector> _localVariables = {}, vector> _globalDeclarations = {})
+ vector> _localVariables = {},
+ vector> _globalDeclarations = {})
{
Parser parser;
ASTPointer sourceUnit;
@@ -99,10 +100,12 @@ bytes compileFirstExpression(const string& _sourceCode, vector> _
NameAndTypeResolver resolver(declarations);
resolver.registerDeclarations(*sourceUnit);
+ vector inheritanceHierarchy;
for (ASTPointer const& node: sourceUnit->getNodes())
if (ContractDefinition* contract = dynamic_cast(node.get()))
{
BOOST_REQUIRE_NO_THROW(resolver.resolveNamesAndTypes(*contract));
+ inheritanceHierarchy = vector(1, contract);
}
for (ASTPointer const& node: sourceUnit->getNodes())
if (ContractDefinition* contract = dynamic_cast(node.get()))
@@ -116,8 +119,7 @@ bytes compileFirstExpression(const string& _sourceCode, vector> _
BOOST_REQUIRE(extractor.getExpression() != nullptr);
CompilerContext context;
- for (vector const& function: _functions)
- context.addFunction(dynamic_cast(resolveDeclaration(function, resolver)));
+ context.setInheritanceHierarchy(inheritanceHierarchy);
unsigned parametersSize = _localVariables.size(); // assume they are all one slot on the stack
context.adjustStackOffset(parametersSize);
for (vector const& variable: _localVariables)