Browse Source

Merge pull request #825 from chriseth/sol_contractInheritance

Contract inheritance
cl-refactor
Gav Wood 10 years ago
parent
commit
471042fea6
  1. 65
      libsolidity/AST.cpp
  2. 59
      libsolidity/AST.h
  3. 1
      libsolidity/ASTForward.h
  4. 22
      libsolidity/AST_accept.h
  5. 44
      libsolidity/CallGraph.cpp
  6. 7
      libsolidity/CallGraph.h
  7. 93
      libsolidity/Compiler.cpp
  8. 15
      libsolidity/Compiler.h
  9. 11
      libsolidity/CompilerContext.cpp
  10. 4
      libsolidity/CompilerContext.h
  11. 4
      libsolidity/CompilerStack.cpp
  12. 5
      libsolidity/DeclarationContainer.h
  13. 34
      libsolidity/ExpressionCompiler.cpp
  14. 12
      libsolidity/GlobalContext.cpp
  15. 2
      libsolidity/GlobalContext.h
  16. 116
      libsolidity/NameAndTypeResolver.cpp
  17. 19
      libsolidity/NameAndTypeResolver.h
  18. 30
      libsolidity/Parser.cpp
  19. 1
      libsolidity/Parser.h
  20. 2
      libsolidity/Token.h
  21. 32
      libsolidity/Types.cpp
  22. 8
      libsolidity/Types.h
  23. 5
      libsolidity/grammar.txt
  24. 2
      test/SolidityCompiler.cpp
  25. 135
      test/SolidityEndToEndTest.cpp
  26. 123
      test/SolidityNameAndTypeResolution.cpp
  27. 45
      test/SolidityParser.cpp

65
libsolidity/AST.cpp

@ -43,6 +43,11 @@ TypeError ASTNode::createTypeError(string const& _description) const
void ContractDefinition::checkTypeRequirements() void ContractDefinition::checkTypeRequirements()
{ {
for (ASTPointer<InheritanceSpecifier> const& base: getBaseContracts())
base->checkTypeRequirements();
checkIllegalOverrides();
FunctionDefinition const* constructor = getConstructor(); FunctionDefinition const* constructor = getConstructor();
if (constructor && !constructor->getReturnParameters().empty()) if (constructor && !constructor->getReturnParameters().empty())
BOOST_THROW_EXCEPTION(constructor->getReturnParameterList()->createTypeError( BOOST_THROW_EXCEPTION(constructor->getReturnParameterList()->createTypeError(
@ -52,7 +57,6 @@ void ContractDefinition::checkTypeRequirements()
function->checkTypeRequirements(); function->checkTypeRequirements();
// check for hash collisions in function signatures // check for hash collisions in function signatures
vector<pair<FixedHash<4>, FunctionDefinition const*>> exportedFunctionList = getInterfaceFunctionList();
set<FixedHash<4>> hashes; set<FixedHash<4>> hashes;
for (auto const& hashAndFunction: getInterfaceFunctionList()) for (auto const& hashAndFunction: getInterfaceFunctionList())
{ {
@ -83,17 +87,59 @@ FunctionDefinition const* ContractDefinition::getConstructor() const
return nullptr; return nullptr;
} }
vector<pair<FixedHash<4>, FunctionDefinition const*>> ContractDefinition::getInterfaceFunctionList() const void ContractDefinition::checkIllegalOverrides() const
{ {
vector<pair<FixedHash<4>, FunctionDefinition const*>> exportedFunctions; map<string, FunctionDefinition const*> functions;
for (ASTPointer<FunctionDefinition> const& f: m_definedFunctions)
if (f->isPublic() && f->getName() != getName()) // We search from derived to base, so the stored item causes the error.
for (ContractDefinition const* contract: getLinearizedBaseContracts())
for (ASTPointer<FunctionDefinition> const& function: contract->getDefinedFunctions())
{
if (function->getName() == contract->getName())
continue; // constructors can neither be overriden nor override anything
FunctionDefinition const*& override = functions[function->getName()];
if (!override)
override = function.get();
else if (override->isPublic() != function->isPublic() ||
override->isDeclaredConst() != function->isDeclaredConst() ||
FunctionType(*override) != FunctionType(*function))
BOOST_THROW_EXCEPTION(override->createTypeError("Override changes extended function signature."));
}
}
vector<pair<FixedHash<4>, FunctionDefinition const*>> const& ContractDefinition::getInterfaceFunctionList() const
{
if (!m_interfaceFunctionList)
{
set<string> functionsSeen;
m_interfaceFunctionList.reset(new vector<pair<FixedHash<4>, FunctionDefinition const*>>());
for (ContractDefinition const* contract: getLinearizedBaseContracts())
for (ASTPointer<FunctionDefinition> const& f: contract->getDefinedFunctions())
if (f->isPublic() && f->getName() != contract->getName() &&
functionsSeen.count(f->getName()) == 0)
{ {
functionsSeen.insert(f->getName());
FixedHash<4> hash(dev::sha3(f->getCanonicalSignature())); FixedHash<4> hash(dev::sha3(f->getCanonicalSignature()));
exportedFunctions.push_back(make_pair(hash, f.get())); m_interfaceFunctionList->push_back(make_pair(hash, f.get()));
} }
}
return *m_interfaceFunctionList;
}
return exportedFunctions; void InheritanceSpecifier::checkTypeRequirements()
{
m_baseName->checkTypeRequirements();
for (ASTPointer<Expression> const& argument: m_arguments)
argument->checkTypeRequirements();
ContractDefinition const* base = dynamic_cast<ContractDefinition const*>(m_baseName->getReferencedDeclaration());
solAssert(base, "Base contract not available.");
TypePointers parameterTypes = ContractType(*base).getConstructorType()->getParameterTypes();
if (parameterTypes.size() != m_arguments.size())
BOOST_THROW_EXCEPTION(createTypeError("Wrong argument count for constructor call."));
for (size_t i = 0; i < m_arguments.size(); ++i)
if (!m_arguments[i]->getType()->isImplicitlyConvertibleTo(*parameterTypes[i]))
BOOST_THROW_EXCEPTION(createTypeError("Invalid type for argument in constructer call."));
} }
void StructDefinition::checkMemberTypes() const void StructDefinition::checkMemberTypes() const
@ -346,7 +392,8 @@ void MemberAccess::checkTypeRequirements()
Type const& type = *m_expression->getType(); Type const& type = *m_expression->getType();
m_type = type.getMemberType(*m_memberName); m_type = type.getMemberType(*m_memberName);
if (!m_type) if (!m_type)
BOOST_THROW_EXCEPTION(createTypeError("Member \"" + *m_memberName + "\" not found in " + type.toString())); BOOST_THROW_EXCEPTION(createTypeError("Member \"" + *m_memberName + "\" not found or not "
"visible in " + type.toString()));
//@todo later, this will not always be STORAGE //@todo later, this will not always be STORAGE
m_lvalue = type.getCategory() == Type::Category::STRUCT ? LValueType::STORAGE : LValueType::NONE; m_lvalue = type.getCategory() == Type::Category::STRUCT ? LValueType::STORAGE : LValueType::NONE;
} }
@ -396,7 +443,7 @@ void Identifier::checkTypeRequirements()
ContractDefinition const* contractDef = dynamic_cast<ContractDefinition const*>(m_referencedDeclaration); ContractDefinition const* contractDef = dynamic_cast<ContractDefinition const*>(m_referencedDeclaration);
if (contractDef) if (contractDef)
{ {
m_type = make_shared<TypeType>(make_shared<ContractType>(*contractDef)); m_type = make_shared<TypeType>(make_shared<ContractType>(*contractDef), m_currentContract);
return; return;
} }
MagicVariableDeclaration const* magicVariable = dynamic_cast<MagicVariableDeclaration const*>(m_referencedDeclaration); MagicVariableDeclaration const* magicVariable = dynamic_cast<MagicVariableDeclaration const*>(m_referencedDeclaration);

59
libsolidity/AST.h

@ -158,10 +158,12 @@ public:
ContractDefinition(Location const& _location, ContractDefinition(Location const& _location,
ASTPointer<ASTString> const& _name, ASTPointer<ASTString> const& _name,
ASTPointer<ASTString> const& _documentation, ASTPointer<ASTString> const& _documentation,
std::vector<ASTPointer<InheritanceSpecifier>> const& _baseContracts,
std::vector<ASTPointer<StructDefinition>> const& _definedStructs, std::vector<ASTPointer<StructDefinition>> const& _definedStructs,
std::vector<ASTPointer<VariableDeclaration>> const& _stateVariables, std::vector<ASTPointer<VariableDeclaration>> const& _stateVariables,
std::vector<ASTPointer<FunctionDefinition>> const& _definedFunctions): std::vector<ASTPointer<FunctionDefinition>> const& _definedFunctions):
Declaration(_location, _name), Declaration(_location, _name),
m_baseContracts(_baseContracts),
m_definedStructs(_definedStructs), m_definedStructs(_definedStructs),
m_stateVariables(_stateVariables), m_stateVariables(_stateVariables),
m_definedFunctions(_definedFunctions), m_definedFunctions(_definedFunctions),
@ -171,12 +173,13 @@ public:
virtual void accept(ASTVisitor& _visitor) override; virtual void accept(ASTVisitor& _visitor) override;
virtual void accept(ASTConstVisitor& _visitor) const override; virtual void accept(ASTConstVisitor& _visitor) const override;
std::vector<ASTPointer<InheritanceSpecifier>> const& getBaseContracts() const { return m_baseContracts; }
std::vector<ASTPointer<StructDefinition>> const& getDefinedStructs() const { return m_definedStructs; } std::vector<ASTPointer<StructDefinition>> const& getDefinedStructs() const { return m_definedStructs; }
std::vector<ASTPointer<VariableDeclaration>> const& getStateVariables() const { return m_stateVariables; } std::vector<ASTPointer<VariableDeclaration>> const& getStateVariables() const { return m_stateVariables; }
std::vector<ASTPointer<FunctionDefinition>> const& getDefinedFunctions() const { return m_definedFunctions; } std::vector<ASTPointer<FunctionDefinition>> const& getDefinedFunctions() const { return m_definedFunctions; }
/// Checks that the constructor does not have a "returns" statement and calls /// Checks that there are no illegal overrides, that the constructor does not have a "returns"
/// checkTypeRequirements on all its functions. /// and calls checkTypeRequirements on all its functions.
void checkTypeRequirements(); void checkTypeRequirements();
/// @return A shared pointer of an ASTString. /// @return A shared pointer of an ASTString.
@ -187,16 +190,47 @@ public:
/// as intended for use by the ABI. /// as intended for use by the ABI.
std::map<FixedHash<4>, FunctionDefinition const*> getInterfaceFunctions() const; std::map<FixedHash<4>, FunctionDefinition const*> getInterfaceFunctions() const;
/// List of all (direct and indirect) base contracts in order from derived to base, including
/// the contract itself. Available after name resolution
std::vector<ContractDefinition const*> const& getLinearizedBaseContracts() const { return m_linearizedBaseContracts; }
void setLinearizedBaseContracts(std::vector<ContractDefinition const*> const& _bases) { m_linearizedBaseContracts = _bases; }
/// Returns the constructor or nullptr if no constructor was specified /// Returns the constructor or nullptr if no constructor was specified
FunctionDefinition const* getConstructor() const; FunctionDefinition const* getConstructor() const;
private: private:
std::vector<std::pair<FixedHash<4>, FunctionDefinition const*>> getInterfaceFunctionList() const; void checkIllegalOverrides() const;
std::vector<std::pair<FixedHash<4>, FunctionDefinition const*>> const& getInterfaceFunctionList() const;
std::vector<ASTPointer<InheritanceSpecifier>> m_baseContracts;
std::vector<ASTPointer<StructDefinition>> m_definedStructs; std::vector<ASTPointer<StructDefinition>> m_definedStructs;
std::vector<ASTPointer<VariableDeclaration>> m_stateVariables; std::vector<ASTPointer<VariableDeclaration>> m_stateVariables;
std::vector<ASTPointer<FunctionDefinition>> m_definedFunctions; std::vector<ASTPointer<FunctionDefinition>> m_definedFunctions;
ASTPointer<ASTString> m_documentation; ASTPointer<ASTString> m_documentation;
std::vector<ContractDefinition const*> m_linearizedBaseContracts;
mutable std::unique_ptr<std::vector<std::pair<FixedHash<4>, FunctionDefinition const*>>> m_interfaceFunctionList;
};
class InheritanceSpecifier: public ASTNode
{
public:
InheritanceSpecifier(Location const& _location, ASTPointer<Identifier> const& _baseName,
std::vector<ASTPointer<Expression>> _arguments):
ASTNode(_location), m_baseName(_baseName), m_arguments(_arguments) {}
virtual void accept(ASTVisitor& _visitor) override;
virtual void accept(ASTConstVisitor& _visitor) const override;
ASTPointer<Identifier> const& getName() const { return m_baseName; }
std::vector<ASTPointer<Expression>> const& getArguments() const { return m_arguments; }
void checkTypeRequirements();
private:
ASTPointer<Identifier> m_baseName;
std::vector<ASTPointer<Expression>> m_arguments;
}; };
class StructDefinition: public Declaration class StructDefinition: public Declaration
@ -581,7 +615,7 @@ public:
virtual void accept(ASTConstVisitor& _visitor) const override; virtual void accept(ASTConstVisitor& _visitor) const override;
virtual void checkTypeRequirements() override; virtual void checkTypeRequirements() override;
void setFunctionReturnParameters(ParameterList& _parameters) { m_returnParameters = &_parameters; } void setFunctionReturnParameters(ParameterList const& _parameters) { m_returnParameters = &_parameters; }
ParameterList const& getFunctionReturnParameters() const ParameterList const& getFunctionReturnParameters() const
{ {
solAssert(m_returnParameters, ""); solAssert(m_returnParameters, "");
@ -593,7 +627,7 @@ private:
ASTPointer<Expression> m_expression; ///< value to return, optional ASTPointer<Expression> m_expression; ///< value to return, optional
/// Pointer to the parameter list of the function, filled by the @ref NameAndTypeResolver. /// Pointer to the parameter list of the function, filled by the @ref NameAndTypeResolver.
ParameterList* m_returnParameters; ParameterList const* m_returnParameters;
}; };
/** /**
@ -870,21 +904,30 @@ class Identifier: public PrimaryExpression
{ {
public: public:
Identifier(Location const& _location, ASTPointer<ASTString> const& _name): Identifier(Location const& _location, ASTPointer<ASTString> const& _name):
PrimaryExpression(_location), m_name(_name), m_referencedDeclaration(nullptr) {} PrimaryExpression(_location), m_name(_name) {}
virtual void accept(ASTVisitor& _visitor) override; virtual void accept(ASTVisitor& _visitor) override;
virtual void accept(ASTConstVisitor& _visitor) const override; virtual void accept(ASTConstVisitor& _visitor) const override;
virtual void checkTypeRequirements() override; virtual void checkTypeRequirements() override;
ASTString const& getName() const { return *m_name; } ASTString const& getName() const { return *m_name; }
void setReferencedDeclaration(Declaration const& _referencedDeclaration) { m_referencedDeclaration = &_referencedDeclaration; } void setReferencedDeclaration(Declaration const& _referencedDeclaration,
ContractDefinition const* _currentContract = nullptr)
{
m_referencedDeclaration = &_referencedDeclaration;
m_currentContract = _currentContract;
}
Declaration const* getReferencedDeclaration() const { return m_referencedDeclaration; } Declaration const* getReferencedDeclaration() const { return m_referencedDeclaration; }
ContractDefinition const* getCurrentContract() const { return m_currentContract; }
private: private:
ASTPointer<ASTString> m_name; ASTPointer<ASTString> m_name;
/// Declaration the name refers to. /// Declaration the name refers to.
Declaration const* m_referencedDeclaration; Declaration const* m_referencedDeclaration = nullptr;
/// Stores a reference to the current contract. This is needed because types of base contracts
/// change depending on the context.
ContractDefinition const* m_currentContract = nullptr;
}; };
/** /**

1
libsolidity/ASTForward.h

@ -38,6 +38,7 @@ class SourceUnit;
class ImportDirective; class ImportDirective;
class Declaration; class Declaration;
class ContractDefinition; class ContractDefinition;
class InheritanceSpecifier;
class StructDefinition; class StructDefinition;
class ParameterList; class ParameterList;
class FunctionDefinition; class FunctionDefinition;

22
libsolidity/AST_accept.h

@ -61,6 +61,7 @@ void ContractDefinition::accept(ASTVisitor& _visitor)
{ {
if (_visitor.visit(*this)) if (_visitor.visit(*this))
{ {
listAccept(m_baseContracts, _visitor);
listAccept(m_definedStructs, _visitor); listAccept(m_definedStructs, _visitor);
listAccept(m_stateVariables, _visitor); listAccept(m_stateVariables, _visitor);
listAccept(m_definedFunctions, _visitor); listAccept(m_definedFunctions, _visitor);
@ -72,6 +73,7 @@ void ContractDefinition::accept(ASTConstVisitor& _visitor) const
{ {
if (_visitor.visit(*this)) if (_visitor.visit(*this))
{ {
listAccept(m_baseContracts, _visitor);
listAccept(m_definedStructs, _visitor); listAccept(m_definedStructs, _visitor);
listAccept(m_stateVariables, _visitor); listAccept(m_stateVariables, _visitor);
listAccept(m_definedFunctions, _visitor); listAccept(m_definedFunctions, _visitor);
@ -79,6 +81,26 @@ void ContractDefinition::accept(ASTConstVisitor& _visitor) const
_visitor.endVisit(*this); _visitor.endVisit(*this);
} }
void InheritanceSpecifier::accept(ASTVisitor& _visitor)
{
if (_visitor.visit(*this))
{
m_baseName->accept(_visitor);
listAccept(m_arguments, _visitor);
}
_visitor.endVisit(*this);
}
void InheritanceSpecifier::accept(ASTConstVisitor& _visitor) const
{
if (_visitor.visit(*this))
{
m_baseName->accept(_visitor);
listAccept(m_arguments, _visitor);
}
_visitor.endVisit(*this);
}
void StructDefinition::accept(ASTVisitor& _visitor) void StructDefinition::accept(ASTVisitor& _visitor)
{ {
if (_visitor.visit(*this)) if (_visitor.visit(*this))

44
libsolidity/CallGraph.cpp

@ -31,13 +31,9 @@ namespace dev
namespace solidity namespace solidity
{ {
void CallGraph::addFunction(FunctionDefinition const& _function) void CallGraph::addNode(ASTNode const& _node)
{ {
if (!m_functionsSeen.count(&_function)) _node.accept(*this);
{
m_functionsSeen.insert(&_function);
m_workQueue.push(&_function);
}
} }
set<FunctionDefinition const*> const& CallGraph::getCalls() set<FunctionDefinition const*> const& CallGraph::getCalls()
@ -63,5 +59,41 @@ bool CallGraph::visit(Identifier const& _identifier)
return true; return true;
} }
bool CallGraph::visit(FunctionDefinition const& _function)
{
addFunction(_function);
return true;
}
bool CallGraph::visit(MemberAccess const& _memberAccess)
{
// used for "BaseContract.baseContractFunction"
if (_memberAccess.getExpression().getType()->getCategory() == Type::Category::TYPE)
{
TypeType const& type = dynamic_cast<TypeType const&>(*_memberAccess.getExpression().getType());
if (type.getMembers().getMemberType(_memberAccess.getMemberName()))
{
ContractDefinition const& contract = dynamic_cast<ContractType const&>(*type.getActualType())
.getContractDefinition();
for (ASTPointer<FunctionDefinition> const& function: contract.getDefinedFunctions())
if (function->getName() == _memberAccess.getMemberName())
{
addFunction(*function);
return true;
}
}
}
return true;
}
void CallGraph::addFunction(FunctionDefinition const& _function)
{
if (!m_functionsSeen.count(&_function))
{
m_functionsSeen.insert(&_function);
m_workQueue.push(&_function);
}
}
} }
} }

7
libsolidity/CallGraph.h

@ -38,14 +38,17 @@ namespace solidity
class CallGraph: private ASTConstVisitor class CallGraph: private ASTConstVisitor
{ {
public: public:
void addFunction(FunctionDefinition const& _function); void addNode(ASTNode const& _node);
void computeCallGraph(); void computeCallGraph();
std::set<FunctionDefinition const*> const& getCalls(); std::set<FunctionDefinition const*> const& getCalls();
private: private:
void addFunctionToQueue(FunctionDefinition const& _function); virtual bool visit(FunctionDefinition const& _function) override;
virtual bool visit(Identifier const& _identifier) override; virtual bool visit(Identifier const& _identifier) override;
virtual bool visit(MemberAccess const& _memberAccess) override;
void addFunction(FunctionDefinition const& _function);
std::set<FunctionDefinition const*> m_functionsSeen; std::set<FunctionDefinition const*> m_functionsSeen;
std::queue<FunctionDefinition const*> m_workQueue; std::queue<FunctionDefinition const*> m_workQueue;

93
libsolidity/Compiler.cpp

@ -21,6 +21,7 @@
*/ */
#include <algorithm> #include <algorithm>
#include <boost/range/adaptor/reversed.hpp>
#include <libevmcore/Instruction.h> #include <libevmcore/Instruction.h>
#include <libevmcore/Assembly.h> #include <libevmcore/Assembly.h>
#include <libsolidity/AST.h> #include <libsolidity/AST.h>
@ -34,48 +35,84 @@ using namespace std;
namespace dev { namespace dev {
namespace solidity { namespace solidity {
void Compiler::compileContract(ContractDefinition const& _contract, vector<MagicVariableDeclaration const*> const& _magicGlobals, void Compiler::compileContract(ContractDefinition const& _contract,
map<ContractDefinition const*, bytes const*> const& _contracts) map<ContractDefinition const*, bytes const*> const& _contracts)
{ {
m_context = CompilerContext(); // clear it just in case m_context = CompilerContext(); // clear it just in case
initializeContext(_contract, _magicGlobals, _contracts); initializeContext(_contract, _contracts);
for (ASTPointer<FunctionDefinition> const& function: _contract.getDefinedFunctions()) for (ContractDefinition const* contract: _contract.getLinearizedBaseContracts())
if (function->getName() != _contract.getName()) // don't add the constructor here for (ASTPointer<FunctionDefinition> const& function: contract->getDefinedFunctions())
if (function->getName() != contract->getName()) // don't add the constructor here
m_context.addFunction(*function); m_context.addFunction(*function);
appendFunctionSelector(_contract); appendFunctionSelector(_contract);
for (ASTPointer<FunctionDefinition> const& function: _contract.getDefinedFunctions()) for (ContractDefinition const* contract: _contract.getLinearizedBaseContracts())
if (function->getName() != _contract.getName()) // don't add the constructor here for (ASTPointer<FunctionDefinition> const& function: contract->getDefinedFunctions())
if (function->getName() != contract->getName()) // don't add the constructor here
function->accept(*this); function->accept(*this);
// Swap the runtime context with the creation-time context // Swap the runtime context with the creation-time context
swap(m_context, m_runtimeContext); swap(m_context, m_runtimeContext);
initializeContext(_contract, _magicGlobals, _contracts); initializeContext(_contract, _contracts);
packIntoContractCreator(_contract, m_runtimeContext); packIntoContractCreator(_contract, m_runtimeContext);
} }
void Compiler::initializeContext(ContractDefinition const& _contract, vector<MagicVariableDeclaration const*> const& _magicGlobals, void Compiler::initializeContext(ContractDefinition const& _contract,
map<ContractDefinition const*, bytes const*> const& _contracts) map<ContractDefinition const*, bytes const*> const& _contracts)
{ {
m_context.setCompiledContracts(_contracts); m_context.setCompiledContracts(_contracts);
for (MagicVariableDeclaration const* variable: _magicGlobals)
m_context.addMagicGlobal(*variable);
registerStateVariables(_contract); registerStateVariables(_contract);
} }
void Compiler::packIntoContractCreator(ContractDefinition const& _contract, CompilerContext const& _runtimeContext) void Compiler::packIntoContractCreator(ContractDefinition const& _contract, CompilerContext const& _runtimeContext)
{ {
// arguments for base constructors, filled in derived-to-base order
map<ContractDefinition const*, vector<ASTPointer<Expression>> const*> baseArguments;
set<FunctionDefinition const*> neededFunctions; set<FunctionDefinition const*> neededFunctions;
FunctionDefinition const* constructor = _contract.getConstructor(); set<ASTNode const*> nodesUsedInConstructors;
if (constructor)
neededFunctions = getFunctionsNeededByConstructor(*constructor); // Determine the arguments that are used for the base constructors and also which functions
// are needed at compile time.
std::vector<ContractDefinition const*> const& bases = _contract.getLinearizedBaseContracts();
for (ContractDefinition const* contract: bases)
{
if (FunctionDefinition const* constructor = contract->getConstructor())
nodesUsedInConstructors.insert(constructor);
for (ASTPointer<InheritanceSpecifier> const& base: contract->getBaseContracts())
{
ContractDefinition const* baseContract = dynamic_cast<ContractDefinition const*>(
base->getName()->getReferencedDeclaration());
solAssert(baseContract, "");
if (baseArguments.count(baseContract) == 0)
{
baseArguments[baseContract] = &base->getArguments();
for (ASTPointer<Expression> const& arg: base->getArguments())
nodesUsedInConstructors.insert(arg.get());
}
}
}
//@TODO add virtual functions
neededFunctions = getFunctionsCalled(nodesUsedInConstructors);
for (FunctionDefinition const* fun: neededFunctions) for (FunctionDefinition const* fun: neededFunctions)
m_context.addFunction(*fun); m_context.addFunction(*fun);
if (constructor) // Call constructors in base-to-derived order.
appendConstructorCall(*constructor); // The Constructor for the most derived contract is called later.
for (unsigned i = 1; i < bases.size(); i++)
{
ContractDefinition const* base = bases[bases.size() - i];
solAssert(base, "");
FunctionDefinition const* baseConstructor = base->getConstructor();
if (!baseConstructor)
continue;
solAssert(baseArguments[base], "");
appendBaseConstructorCall(*baseConstructor, *baseArguments[base]);
}
if (_contract.getConstructor())
appendConstructorCall(*_contract.getConstructor());
eth::AssemblyItem sub = m_context.addSubroutine(_runtimeContext.getAssembly()); eth::AssemblyItem sub = m_context.addSubroutine(_runtimeContext.getAssembly());
// stack contains sub size // stack contains sub size
@ -88,6 +125,21 @@ void Compiler::packIntoContractCreator(ContractDefinition const& _contract, Comp
fun->accept(*this); fun->accept(*this);
} }
void Compiler::appendBaseConstructorCall(FunctionDefinition const& _constructor,
vector<ASTPointer<Expression>> const& _arguments)
{
FunctionType constructorType(_constructor);
eth::AssemblyItem returnLabel = m_context.pushNewTag();
for (unsigned i = 0; i < _arguments.size(); ++i)
{
compileExpression(*_arguments[i]);
ExpressionCompiler::appendTypeConversion(m_context, *_arguments[i]->getType(),
*constructorType.getParameterTypes()[i]);
}
m_context.appendJumpTo(m_context.getFunctionEntryLabel(_constructor));
m_context << returnLabel;
}
void Compiler::appendConstructorCall(FunctionDefinition const& _constructor) void Compiler::appendConstructorCall(FunctionDefinition const& _constructor)
{ {
eth::AssemblyItem returnTag = m_context.pushNewTag(); eth::AssemblyItem returnTag = m_context.pushNewTag();
@ -107,11 +159,12 @@ void Compiler::appendConstructorCall(FunctionDefinition const& _constructor)
m_context << returnTag; m_context << returnTag;
} }
set<FunctionDefinition const*> Compiler::getFunctionsNeededByConstructor(FunctionDefinition const& _constructor) set<FunctionDefinition const*> Compiler::getFunctionsCalled(set<ASTNode const*> const& _nodes)
{ {
// TODO this does not add virtual functions
CallGraph callgraph; CallGraph callgraph;
callgraph.addFunction(_constructor); for (ASTNode const* node: _nodes)
callgraph.computeCallGraph(); callgraph.addNode(*node);
return callgraph.getCalls(); return callgraph.getCalls();
} }
@ -193,8 +246,8 @@ void Compiler::appendReturnValuePacker(FunctionDefinition const& _function)
void Compiler::registerStateVariables(ContractDefinition const& _contract) void Compiler::registerStateVariables(ContractDefinition const& _contract)
{ {
//@todo sort them? for (ContractDefinition const* contract: boost::adaptors::reverse(_contract.getLinearizedBaseContracts()))
for (ASTPointer<VariableDeclaration> const& variable: _contract.getStateVariables()) for (ASTPointer<VariableDeclaration> const& variable: contract->getStateVariables())
m_context.addStateVariable(*variable); m_context.addStateVariable(*variable);
} }

15
libsolidity/Compiler.h

@ -32,23 +32,24 @@ class Compiler: private ASTConstVisitor
public: public:
explicit Compiler(bool _optimize = false): m_optimize(_optimize), m_context(), m_returnTag(m_context.newTag()) {} explicit Compiler(bool _optimize = false): m_optimize(_optimize), m_context(), m_returnTag(m_context.newTag()) {}
void compileContract(ContractDefinition const& _contract, std::vector<MagicVariableDeclaration const*> const& _magicGlobals, void compileContract(ContractDefinition const& _contract,
std::map<ContractDefinition const*, bytes const*> const& _contracts); std::map<ContractDefinition const*, bytes const*> const& _contracts);
bytes getAssembledBytecode() { return m_context.getAssembledBytecode(m_optimize); } bytes getAssembledBytecode() { return m_context.getAssembledBytecode(m_optimize); }
bytes getRuntimeBytecode() { return m_runtimeContext.getAssembledBytecode(m_optimize);} bytes getRuntimeBytecode() { return m_runtimeContext.getAssembledBytecode(m_optimize);}
void streamAssembly(std::ostream& _stream) const { m_context.streamAssembly(_stream); } void streamAssembly(std::ostream& _stream) const { m_context.streamAssembly(_stream); }
private: private:
/// Registers the global objects and the non-function objects inside the contract with the context. /// Registers the non-function objects inside the contract with the context.
void initializeContext(ContractDefinition const& _contract, std::vector<MagicVariableDeclaration const*> const& _magicGlobals, void initializeContext(ContractDefinition const& _contract,
std::map<ContractDefinition const*, bytes const*> const& _contracts); std::map<ContractDefinition const*, bytes const*> const& _contracts);
/// Adds the code that is run at creation time. Should be run after exchanging the run-time context /// Adds the code that is run at creation time. Should be run after exchanging the run-time context
/// with a new and initialized context. /// with a new and initialized context. Adds the constructor code.
/// adds the constructor code.
void packIntoContractCreator(ContractDefinition const& _contract, CompilerContext const& _runtimeContext); void packIntoContractCreator(ContractDefinition const& _contract, CompilerContext const& _runtimeContext);
void appendBaseConstructorCall(FunctionDefinition const& _constructor,
std::vector<ASTPointer<Expression>> const& _arguments);
void appendConstructorCall(FunctionDefinition const& _constructor); void appendConstructorCall(FunctionDefinition const& _constructor);
/// Recursively searches the call graph and returns all functions needed by the constructor (including itself). /// Recursively searches the call graph and returns all functions referenced inside _nodes.
std::set<FunctionDefinition const*> getFunctionsNeededByConstructor(FunctionDefinition const& _constructor); std::set<FunctionDefinition const*> getFunctionsCalled(std::set<ASTNode const*> const& _nodes);
void appendFunctionSelector(ContractDefinition const& _contract); void appendFunctionSelector(ContractDefinition const& _contract);
/// Creates code that unpacks the arguments for the given function, from memory if /// Creates code that unpacks the arguments for the given function, from memory if
/// @a _fromMemory is true, otherwise from call data. @returns the size of the data in bytes. /// @a _fromMemory is true, otherwise from call data. @returns the size of the data in bytes.

11
libsolidity/CompilerContext.cpp

@ -61,7 +61,9 @@ void CompilerContext::addAndInitializeVariable(VariableDeclaration const& _decla
void CompilerContext::addFunction(FunctionDefinition const& _function) void CompilerContext::addFunction(FunctionDefinition const& _function)
{ {
m_functionEntryLabels.insert(std::make_pair(&_function, m_asm.newTag())); eth::AssemblyItem tag(m_asm.newTag());
m_functionEntryLabels.insert(make_pair(&_function, tag));
m_virtualFunctionEntryLabels.insert(make_pair(_function.getName(), tag));
} }
bytes const& CompilerContext::getCompiledContract(const ContractDefinition& _contract) const bytes const& CompilerContext::getCompiledContract(const ContractDefinition& _contract) const
@ -83,6 +85,13 @@ eth::AssemblyItem CompilerContext::getFunctionEntryLabel(FunctionDefinition cons
return res->second.tag(); return res->second.tag();
} }
eth::AssemblyItem CompilerContext::getVirtualFunctionEntryLabel(FunctionDefinition const& _function) const
{
auto res = m_virtualFunctionEntryLabels.find(_function.getName());
solAssert(res != m_virtualFunctionEntryLabels.end(), "Function entry label not found.");
return res->second.tag();
}
unsigned CompilerContext::getBaseStackOffsetOfVariable(Declaration const& _declaration) const unsigned CompilerContext::getBaseStackOffsetOfVariable(Declaration const& _declaration) const
{ {
auto res = m_localVariables.find(&_declaration); auto res = m_localVariables.find(&_declaration);

4
libsolidity/CompilerContext.h

@ -57,6 +57,8 @@ public:
bool isStateVariable(Declaration const* _declaration) const { return m_stateVariables.count(_declaration) != 0; } bool isStateVariable(Declaration const* _declaration) const { return m_stateVariables.count(_declaration) != 0; }
eth::AssemblyItem getFunctionEntryLabel(FunctionDefinition const& _function) const; eth::AssemblyItem getFunctionEntryLabel(FunctionDefinition const& _function) const;
/// @returns the entry label of the given function and takes overrides into account.
eth::AssemblyItem getVirtualFunctionEntryLabel(FunctionDefinition const& _function) const;
/// Returns the distance of the given local variable from the top of the local variable stack. /// Returns the distance of the given local variable from the top of the local variable stack.
unsigned getBaseStackOffsetOfVariable(Declaration const& _declaration) const; unsigned getBaseStackOffsetOfVariable(Declaration const& _declaration) const;
/// If supplied by a value returned by @ref getBaseStackOffsetOfVariable(variable), returns /// If supplied by a value returned by @ref getBaseStackOffsetOfVariable(variable), returns
@ -116,6 +118,8 @@ private:
unsigned m_localVariablesSize; unsigned m_localVariablesSize;
/// Labels pointing to the entry points of funcitons. /// Labels pointing to the entry points of funcitons.
std::map<Declaration const*, eth::AssemblyItem> m_functionEntryLabels; std::map<Declaration const*, eth::AssemblyItem> m_functionEntryLabels;
/// Labels pointing to the entry points of function overrides.
std::map<std::string, eth::AssemblyItem> m_virtualFunctionEntryLabels;
}; };
} }

4
libsolidity/CompilerStack.cpp

@ -113,10 +113,8 @@ void CompilerStack::compile(bool _optimize)
for (ASTPointer<ASTNode> const& node: source->ast->getNodes()) for (ASTPointer<ASTNode> const& node: source->ast->getNodes())
if (ContractDefinition* contract = dynamic_cast<ContractDefinition*>(node.get())) if (ContractDefinition* contract = dynamic_cast<ContractDefinition*>(node.get()))
{ {
m_globalContext->setCurrentContract(*contract);
shared_ptr<Compiler> compiler = make_shared<Compiler>(_optimize); shared_ptr<Compiler> compiler = make_shared<Compiler>(_optimize);
compiler->compileContract(*contract, m_globalContext->getMagicVariables(), compiler->compileContract(*contract, contractBytecode);
contractBytecode);
Contract& compiledContract = m_contracts[contract->getName()]; Contract& compiledContract = m_contracts[contract->getName()];
compiledContract.bytecode = compiler->getAssembledBytecode(); compiledContract.bytecode = compiler->getAssembledBytecode();
compiledContract.runtimeBytecode = compiler->getRuntimeBytecode(); compiledContract.runtimeBytecode = compiler->getRuntimeBytecode();

5
libsolidity/DeclarationContainer.h

@ -42,11 +42,12 @@ public:
explicit DeclarationContainer(Declaration const* _enclosingDeclaration = nullptr, explicit DeclarationContainer(Declaration const* _enclosingDeclaration = nullptr,
DeclarationContainer const* _enclosingContainer = nullptr): DeclarationContainer const* _enclosingContainer = nullptr):
m_enclosingDeclaration(_enclosingDeclaration), m_enclosingContainer(_enclosingContainer) {} m_enclosingDeclaration(_enclosingDeclaration), m_enclosingContainer(_enclosingContainer) {}
/// Registers the declaration in the scope unless its name is already declared. Returns true iff /// Registers the declaration in the scope unless its name is already declared.
/// it was not yet declared. /// @returns true iff it was not yet declared.
bool registerDeclaration(Declaration const& _declaration, bool _update = false); bool registerDeclaration(Declaration const& _declaration, bool _update = false);
Declaration const* resolveName(ASTString const& _name, bool _recursive = false) const; Declaration const* resolveName(ASTString const& _name, bool _recursive = false) const;
Declaration const* getEnclosingDeclaration() const { return m_enclosingDeclaration; } Declaration const* getEnclosingDeclaration() const { return m_enclosingDeclaration; }
std::map<ASTString, Declaration const*> const& getDeclarations() const { return m_declarations; }
private: private:
Declaration const* m_enclosingDeclaration; Declaration const* m_enclosingDeclaration;

34
libsolidity/ExpressionCompiler.cpp

@ -419,6 +419,22 @@ void ExpressionCompiler::endVisit(MemberAccess const& _memberAccess)
m_currentLValue.retrieveValueIfLValueNotRequested(_memberAccess); m_currentLValue.retrieveValueIfLValueNotRequested(_memberAccess);
break; break;
} }
case Type::Category::TYPE:
{
TypeType const& type = dynamic_cast<TypeType const&>(*_memberAccess.getExpression().getType());
if (type.getMembers().getMemberType(member))
{
ContractDefinition const& contract = dynamic_cast<ContractType const&>(*type.getActualType())
.getContractDefinition();
for (ASTPointer<FunctionDefinition> const& function: contract.getDefinedFunctions())
if (function->getName() == member)
{
m_context << m_context.getFunctionEntryLabel(*function).pushTag();
return;
}
}
BOOST_THROW_EXCEPTION(InternalCompilerError() << errinfo_comment("Invalid member access to " + type.toString()));
}
default: default:
BOOST_THROW_EXCEPTION(InternalCompilerError() << errinfo_comment("Member access to unknown type.")); BOOST_THROW_EXCEPTION(InternalCompilerError() << errinfo_comment("Member access to unknown type."));
} }
@ -449,20 +465,22 @@ void ExpressionCompiler::endVisit(Identifier const& _identifier)
{ {
if (magicVar->getType()->getCategory() == Type::Category::CONTRACT) // must be "this" if (magicVar->getType()->getCategory() == Type::Category::CONTRACT) // must be "this"
m_context << eth::Instruction::ADDRESS; m_context << eth::Instruction::ADDRESS;
return;
} }
if (FunctionDefinition const* functionDef = dynamic_cast<FunctionDefinition const*>(declaration)) else if (FunctionDefinition const* functionDef = dynamic_cast<FunctionDefinition const*>(declaration))
{ m_context << m_context.getVirtualFunctionEntryLabel(*functionDef).pushTag();
m_context << m_context.getFunctionEntryLabel(*functionDef).pushTag(); else if (dynamic_cast<VariableDeclaration const*>(declaration))
return;
}
if (dynamic_cast<VariableDeclaration const*>(declaration))
{ {
m_currentLValue.fromIdentifier(_identifier, *declaration); m_currentLValue.fromIdentifier(_identifier, *declaration);
m_currentLValue.retrieveValueIfLValueNotRequested(_identifier); m_currentLValue.retrieveValueIfLValueNotRequested(_identifier);
return;
} }
else if (dynamic_cast<ContractDefinition const*>(declaration))
{
// no-op
}
else
{
BOOST_THROW_EXCEPTION(InternalCompilerError() << errinfo_comment("Identifier type not expected in expression context.")); BOOST_THROW_EXCEPTION(InternalCompilerError() << errinfo_comment("Identifier type not expected in expression context."));
}
} }
void ExpressionCompiler::endVisit(Literal const& _literal) void ExpressionCompiler::endVisit(Literal const& _literal)

12
libsolidity/GlobalContext.cpp

@ -68,7 +68,7 @@ void GlobalContext::setCurrentContract(ContractDefinition const& _contract)
vector<Declaration const*> GlobalContext::getDeclarations() const vector<Declaration const*> GlobalContext::getDeclarations() const
{ {
vector<Declaration const*> declarations; vector<Declaration const*> declarations;
declarations.reserve(m_magicVariables.size() + 1); declarations.reserve(m_magicVariables.size());
for (ASTPointer<Declaration const> const& variable: m_magicVariables) for (ASTPointer<Declaration const> const& variable: m_magicVariables)
declarations.push_back(variable.get()); declarations.push_back(variable.get());
return declarations; return declarations;
@ -83,15 +83,5 @@ MagicVariableDeclaration const* GlobalContext::getCurrentThis() const
} }
vector<MagicVariableDeclaration const*> GlobalContext::getMagicVariables() const
{
vector<MagicVariableDeclaration const*> declarations;
declarations.reserve(m_magicVariables.size() + 1);
for (ASTPointer<MagicVariableDeclaration const> const& variable: m_magicVariables)
declarations.push_back(variable.get());
declarations.push_back(getCurrentThis());
return declarations;
}
} }
} }

2
libsolidity/GlobalContext.h

@ -49,8 +49,6 @@ public:
void setCurrentContract(ContractDefinition const& _contract); void setCurrentContract(ContractDefinition const& _contract);
MagicVariableDeclaration const* getCurrentThis() const; MagicVariableDeclaration const* getCurrentThis() const;
/// @returns all magic variables.
std::vector<MagicVariableDeclaration const*> getMagicVariables() const;
/// @returns a vector of all implicit global declarations excluding "this". /// @returns a vector of all implicit global declarations excluding "this".
std::vector<Declaration const*> getDeclarations() const; std::vector<Declaration const*> getDeclarations() const;

116
libsolidity/NameAndTypeResolver.cpp

@ -31,7 +31,6 @@ namespace dev
namespace solidity namespace solidity
{ {
NameAndTypeResolver::NameAndTypeResolver(std::vector<Declaration const*> const& _globals) NameAndTypeResolver::NameAndTypeResolver(std::vector<Declaration const*> const& _globals)
{ {
for (Declaration const* declaration: _globals) for (Declaration const* declaration: _globals)
@ -46,18 +45,27 @@ void NameAndTypeResolver::registerDeclarations(SourceUnit& _sourceUnit)
void NameAndTypeResolver::resolveNamesAndTypes(ContractDefinition& _contract) void NameAndTypeResolver::resolveNamesAndTypes(ContractDefinition& _contract)
{ {
m_currentScope = &m_scopes[nullptr];
for (ASTPointer<InheritanceSpecifier> const& baseContract: _contract.getBaseContracts())
ReferencesResolver resolver(*baseContract, *this, &_contract, nullptr);
m_currentScope = &m_scopes[&_contract]; m_currentScope = &m_scopes[&_contract];
linearizeBaseContracts(_contract);
for (ContractDefinition const* base: _contract.getLinearizedBaseContracts())
importInheritedScope(*base);
for (ASTPointer<StructDefinition> const& structDef: _contract.getDefinedStructs()) for (ASTPointer<StructDefinition> const& structDef: _contract.getDefinedStructs())
ReferencesResolver resolver(*structDef, *this, nullptr); ReferencesResolver resolver(*structDef, *this, &_contract, nullptr);
for (ASTPointer<VariableDeclaration> const& variable: _contract.getStateVariables()) for (ASTPointer<VariableDeclaration> const& variable: _contract.getStateVariables())
ReferencesResolver resolver(*variable, *this, nullptr); ReferencesResolver resolver(*variable, *this, &_contract, nullptr);
for (ASTPointer<FunctionDefinition> const& function: _contract.getDefinedFunctions()) for (ASTPointer<FunctionDefinition> const& function: _contract.getDefinedFunctions())
{ {
m_currentScope = &m_scopes[function.get()]; m_currentScope = &m_scopes[function.get()];
ReferencesResolver referencesResolver(*function, *this, ReferencesResolver referencesResolver(*function, *this, &_contract,
function->getReturnParameterList().get()); function->getReturnParameterList().get());
} }
m_currentScope = &m_scopes[nullptr];
} }
void NameAndTypeResolver::checkTypeRequirements(ContractDefinition& _contract) void NameAndTypeResolver::checkTypeRequirements(ContractDefinition& _contract)
@ -86,6 +94,96 @@ Declaration const* NameAndTypeResolver::getNameFromCurrentScope(ASTString const&
return m_currentScope->resolveName(_name, _recursive); return m_currentScope->resolveName(_name, _recursive);
} }
void NameAndTypeResolver::importInheritedScope(ContractDefinition const& _base)
{
auto iterator = m_scopes.find(&_base);
solAssert(iterator != end(m_scopes), "");
for (auto const& nameAndDeclaration: iterator->second.getDeclarations())
{
Declaration const* declaration = nameAndDeclaration.second;
// Import if it was declared in the base and is not the constructor
if (declaration->getScope() == &_base && declaration->getName() != _base.getName())
m_currentScope->registerDeclaration(*declaration);
}
}
void NameAndTypeResolver::linearizeBaseContracts(ContractDefinition& _contract) const
{
// order in the lists is from derived to base
// list of lists to linearize, the last element is the list of direct bases
list<list<ContractDefinition const*>> input(1, {&_contract});
for (ASTPointer<InheritanceSpecifier> const& baseSpecifier: _contract.getBaseContracts())
{
ASTPointer<Identifier> baseName = baseSpecifier->getName();
ContractDefinition const* base = dynamic_cast<ContractDefinition const*>(
baseName->getReferencedDeclaration());
if (!base)
BOOST_THROW_EXCEPTION(baseName->createTypeError("Contract expected."));
// "push_back" has the effect that bases mentioned earlier can overwrite members of bases
// mentioned later
input.back().push_back(base);
vector<ContractDefinition const*> const& basesBases = base->getLinearizedBaseContracts();
if (basesBases.empty())
BOOST_THROW_EXCEPTION(baseName->createTypeError("Definition of base has to precede definition of derived contract"));
input.push_front(list<ContractDefinition const*>(basesBases.begin(), basesBases.end()));
}
vector<ContractDefinition const*> result = cThreeMerge(input);
if (result.empty())
BOOST_THROW_EXCEPTION(_contract.createTypeError("Linearization of inheritance graph impossible"));
_contract.setLinearizedBaseContracts(result);
}
template <class _T>
vector<_T const*> NameAndTypeResolver::cThreeMerge(list<list<_T const*>>& _toMerge)
{
// returns true iff _candidate appears only as last element of the lists
auto appearsOnlyAtHead = [&](_T const* _candidate) -> bool
{
for (list<_T const*> const& bases: _toMerge)
{
solAssert(!bases.empty(), "");
if (find(++bases.begin(), bases.end(), _candidate) != bases.end())
return false;
}
return true;
};
// returns the next candidate to append to the linearized list or nullptr on failure
auto nextCandidate = [&]() -> _T const*
{
for (list<_T const*> const& bases: _toMerge)
{
solAssert(!bases.empty(), "");
if (appearsOnlyAtHead(bases.front()))
return bases.front();
}
return nullptr;
};
// removes the given contract from all lists
auto removeCandidate = [&](_T const* _candidate)
{
for (auto it = _toMerge.begin(); it != _toMerge.end();)
{
it->remove(_candidate);
if (it->empty())
it = _toMerge.erase(it);
else
++it;
}
};
_toMerge.remove_if([](list<_T const*> const& _bases) { return _bases.empty(); });
vector<_T const*> result;
while (!_toMerge.empty())
{
_T const* candidate = nextCandidate();
if (!candidate)
return vector<_T const*>();
result.push_back(candidate);
removeCandidate(candidate);
}
return result;
}
DeclarationRegistrationHelper::DeclarationRegistrationHelper(map<ASTNode const*, DeclarationContainer>& _scopes, DeclarationRegistrationHelper::DeclarationRegistrationHelper(map<ASTNode const*, DeclarationContainer>& _scopes,
ASTNode& _astRoot): ASTNode& _astRoot):
m_scopes(_scopes), m_currentScope(nullptr) m_scopes(_scopes), m_currentScope(nullptr)
@ -169,8 +267,10 @@ void DeclarationRegistrationHelper::registerDeclaration(Declaration& _declaratio
} }
ReferencesResolver::ReferencesResolver(ASTNode& _root, NameAndTypeResolver& _resolver, ReferencesResolver::ReferencesResolver(ASTNode& _root, NameAndTypeResolver& _resolver,
ParameterList* _returnParameters, bool _allowLazyTypes): ContractDefinition const* _currentContract,
m_resolver(_resolver), m_returnParameters(_returnParameters), m_allowLazyTypes(_allowLazyTypes) ParameterList const* _returnParameters, bool _allowLazyTypes):
m_resolver(_resolver), m_currentContract(_currentContract),
m_returnParameters(_returnParameters), m_allowLazyTypes(_allowLazyTypes)
{ {
_root.accept(*this); _root.accept(*this);
} }
@ -218,7 +318,7 @@ bool ReferencesResolver::visit(Identifier& _identifier)
if (!declaration) if (!declaration)
BOOST_THROW_EXCEPTION(DeclarationError() << errinfo_sourceLocation(_identifier.getLocation()) BOOST_THROW_EXCEPTION(DeclarationError() << errinfo_sourceLocation(_identifier.getLocation())
<< errinfo_comment("Undeclared identifier.")); << errinfo_comment("Undeclared identifier."));
_identifier.setReferencedDeclaration(*declaration); _identifier.setReferencedDeclaration(*declaration, m_currentContract);
return false; return false;
} }

19
libsolidity/NameAndTypeResolver.h

@ -23,6 +23,7 @@
#pragma once #pragma once
#include <map> #include <map>
#include <list>
#include <boost/noncopyable.hpp> #include <boost/noncopyable.hpp>
#include <libsolidity/DeclarationContainer.h> #include <libsolidity/DeclarationContainer.h>
@ -64,6 +65,17 @@ public:
private: private:
void reset(); void reset();
/// Imports all members declared directly in the given contract (i.e. does not import inherited
/// members) into the current scope if they are not present already.
void importInheritedScope(ContractDefinition const& _base);
/// Computes "C3-Linearization" of base contracts and stores it inside the contract.
void linearizeBaseContracts(ContractDefinition& _contract) const;
/// Computes the C3-merge of the given list of lists of bases.
/// @returns the linearized vector or an empty vector if linearization is not possible.
template <class _T>
static std::vector<_T const*> cThreeMerge(std::list<std::list<_T const*>>& _toMerge);
/// Maps nodes declaring a scope to scopes, i.e. ContractDefinition and FunctionDeclaration, /// Maps nodes declaring a scope to scopes, i.e. ContractDefinition and FunctionDeclaration,
/// where nullptr denotes the global scope. Note that structs are not scope since they do /// where nullptr denotes the global scope. Note that structs are not scope since they do
/// not contain code. /// not contain code.
@ -108,7 +120,9 @@ class ReferencesResolver: private ASTVisitor
{ {
public: public:
ReferencesResolver(ASTNode& _root, NameAndTypeResolver& _resolver, ReferencesResolver(ASTNode& _root, NameAndTypeResolver& _resolver,
ParameterList* _returnParameters, bool _allowLazyTypes = true); ContractDefinition const* _currentContract,
ParameterList const* _returnParameters,
bool _allowLazyTypes = true);
private: private:
virtual void endVisit(VariableDeclaration& _variable) override; virtual void endVisit(VariableDeclaration& _variable) override;
@ -118,7 +132,8 @@ private:
virtual bool visit(Return& _return) override; virtual bool visit(Return& _return) override;
NameAndTypeResolver& m_resolver; NameAndTypeResolver& m_resolver;
ParameterList* m_returnParameters; ContractDefinition const* m_currentContract;
ParameterList const* m_returnParameters;
bool m_allowLazyTypes; bool m_allowLazyTypes;
}; };

30
libsolidity/Parser.cpp

@ -117,10 +117,18 @@ ASTPointer<ContractDefinition> Parser::parseContractDefinition()
docstring = make_shared<ASTString>(m_scanner->getCurrentCommentLiteral()); docstring = make_shared<ASTString>(m_scanner->getCurrentCommentLiteral());
expectToken(Token::CONTRACT); expectToken(Token::CONTRACT);
ASTPointer<ASTString> name = expectIdentifierToken(); ASTPointer<ASTString> name = expectIdentifierToken();
expectToken(Token::LBRACE); vector<ASTPointer<InheritanceSpecifier>> baseContracts;
vector<ASTPointer<StructDefinition>> structs; vector<ASTPointer<StructDefinition>> structs;
vector<ASTPointer<VariableDeclaration>> stateVariables; vector<ASTPointer<VariableDeclaration>> stateVariables;
vector<ASTPointer<FunctionDefinition>> functions; vector<ASTPointer<FunctionDefinition>> functions;
if (m_scanner->getCurrentToken() == Token::IS)
do
{
m_scanner->next();
baseContracts.push_back(parseInheritanceSpecifier());
}
while (m_scanner->getCurrentToken() == Token::COMMA);
expectToken(Token::LBRACE);
bool visibilityIsPublic = true; bool visibilityIsPublic = true;
while (true) while (true)
{ {
@ -149,7 +157,25 @@ ASTPointer<ContractDefinition> Parser::parseContractDefinition()
} }
nodeFactory.markEndPosition(); nodeFactory.markEndPosition();
expectToken(Token::RBRACE); expectToken(Token::RBRACE);
return nodeFactory.createNode<ContractDefinition>(name, docstring, structs, stateVariables, functions); return nodeFactory.createNode<ContractDefinition>(name, docstring, baseContracts, structs,
stateVariables, functions);
}
ASTPointer<InheritanceSpecifier> Parser::parseInheritanceSpecifier()
{
ASTNodeFactory nodeFactory(*this);
ASTPointer<Identifier> name = ASTNodeFactory(*this).createNode<Identifier>(expectIdentifierToken());
vector<ASTPointer<Expression>> arguments;
if (m_scanner->getCurrentToken() == Token::LPAREN)
{
m_scanner->next();
arguments = parseFunctionCallArguments();
nodeFactory.markEndPosition();
expectToken(Token::RPAREN);
}
else
nodeFactory.setEndPositionFromNode(name);
return nodeFactory.createNode<InheritanceSpecifier>(name, arguments);
} }
ASTPointer<FunctionDefinition> Parser::parseFunctionDefinition(bool _isPublic) ASTPointer<FunctionDefinition> Parser::parseFunctionDefinition(bool _isPublic)

1
libsolidity/Parser.h

@ -49,6 +49,7 @@ private:
///@name Parsing functions for the AST nodes ///@name Parsing functions for the AST nodes
ASTPointer<ImportDirective> parseImportDirective(); ASTPointer<ImportDirective> parseImportDirective();
ASTPointer<ContractDefinition> parseContractDefinition(); ASTPointer<ContractDefinition> parseContractDefinition();
ASTPointer<InheritanceSpecifier> parseInheritanceSpecifier();
ASTPointer<FunctionDefinition> parseFunctionDefinition(bool _isPublic); ASTPointer<FunctionDefinition> parseFunctionDefinition(bool _isPublic);
ASTPointer<StructDefinition> parseStructDefinition(); ASTPointer<StructDefinition> parseStructDefinition();
ASTPointer<VariableDeclaration> parseVariableDeclaration(bool _allowVar); ASTPointer<VariableDeclaration> parseVariableDeclaration(bool _allowVar);

2
libsolidity/Token.h

@ -153,7 +153,7 @@ namespace solidity
K(DEFAULT, "default", 0) \ K(DEFAULT, "default", 0) \
K(DO, "do", 0) \ K(DO, "do", 0) \
K(ELSE, "else", 0) \ K(ELSE, "else", 0) \
K(EXTENDS, "extends", 0) \ K(IS, "is", 0) \
K(FOR, "for", 0) \ K(FOR, "for", 0) \
K(FUNCTION, "function", 0) \ K(FUNCTION, "function", 0) \
K(IF, "if", 0) \ K(IF, "if", 0) \

32
libsolidity/Types.cpp

@ -431,12 +431,19 @@ bool ContractType::isImplicitlyConvertibleTo(Type const& _convertTo) const
return true; return true;
if (_convertTo.getCategory() == Category::INTEGER) if (_convertTo.getCategory() == Category::INTEGER)
return dynamic_cast<IntegerType const&>(_convertTo).isAddress(); return dynamic_cast<IntegerType const&>(_convertTo).isAddress();
if (_convertTo.getCategory() == Category::CONTRACT)
{
auto const& bases = getContractDefinition().getLinearizedBaseContracts();
return find(bases.begin(), bases.end(),
&dynamic_cast<ContractType const&>(_convertTo).getContractDefinition()) != bases.end();
}
return false; return false;
} }
bool ContractType::isExplicitlyConvertibleTo(Type const& _convertTo) const bool ContractType::isExplicitlyConvertibleTo(Type const& _convertTo) const
{ {
return isImplicitlyConvertibleTo(_convertTo) || _convertTo.getCategory() == Category::INTEGER; return isImplicitlyConvertibleTo(_convertTo) || _convertTo.getCategory() == Category::INTEGER ||
_convertTo.getCategory() == Category::CONTRACT;
} }
TypePointer ContractType::unaryOperatorResult(Token::Value _operator) const TypePointer ContractType::unaryOperatorResult(Token::Value _operator) const
@ -695,6 +702,29 @@ bool TypeType::operator==(Type const& _other) const
return *getActualType() == *other.getActualType(); return *getActualType() == *other.getActualType();
} }
MemberList const& TypeType::getMembers() const
{
// We need to lazy-initialize it because of recursive references.
if (!m_members)
{
map<string, TypePointer> members;
if (m_actualType->getCategory() == Category::CONTRACT && m_currentContract != nullptr)
{
ContractDefinition const& contract = dynamic_cast<ContractType const&>(*m_actualType).getContractDefinition();
vector<ContractDefinition const*> currentBases = m_currentContract->getLinearizedBaseContracts();
if (find(currentBases.begin(), currentBases.end(), &contract) != currentBases.end())
// We are accessing the type of a base contract, so add all public and private
// functions. Note that this does not add inherited functions on purpose.
for (ASTPointer<FunctionDefinition> const& f: contract.getDefinedFunctions())
if (f->getName() != contract.getName())
members[f->getName()] = make_shared<FunctionType>(*f);
}
m_members.reset(new MemberList(members));
}
return *m_members;
}
MagicType::MagicType(MagicType::Kind _kind): MagicType::MagicType(MagicType::Kind _kind):
m_kind(_kind) m_kind(_kind)
{ {

8
libsolidity/Types.h

@ -442,7 +442,8 @@ class TypeType: public Type
{ {
public: public:
virtual Category getCategory() const override { return Category::TYPE; } virtual Category getCategory() const override { return Category::TYPE; }
TypeType(TypePointer const& _actualType): m_actualType(_actualType) {} TypeType(TypePointer const& _actualType, ContractDefinition const* _currentContract = nullptr):
m_actualType(_actualType), m_currentContract(_currentContract) {}
TypePointer const& getActualType() const { return m_actualType; } TypePointer const& getActualType() const { return m_actualType; }
virtual TypePointer binaryOperatorResult(Token::Value, TypePointer const&) const override { return TypePointer(); } virtual TypePointer binaryOperatorResult(Token::Value, TypePointer const&) const override { return TypePointer(); }
@ -451,9 +452,14 @@ public:
virtual u256 getStorageSize() const override { BOOST_THROW_EXCEPTION(InternalCompilerError() << errinfo_comment("Storage size of non-storable type type requested.")); } virtual u256 getStorageSize() const override { BOOST_THROW_EXCEPTION(InternalCompilerError() << errinfo_comment("Storage size of non-storable type type requested.")); }
virtual bool canLiveOutsideStorage() const override { return false; } virtual bool canLiveOutsideStorage() const override { return false; }
virtual std::string toString() const override { return "type(" + m_actualType->toString() + ")"; } virtual std::string toString() const override { return "type(" + m_actualType->toString() + ")"; }
virtual MemberList const& getMembers() const override;
private: private:
TypePointer m_actualType; TypePointer m_actualType;
/// Context in which this type is used (influences visibility etc.), can be nullptr.
ContractDefinition const* m_currentContract;
/// List of member types, will be lazy-initialized because of recursive references.
mutable std::unique_ptr<MemberList> m_members;
}; };

5
libsolidity/grammar.txt

@ -1,7 +1,10 @@
ContractDefinition = 'contract' Identifier '{' ContractPart* '}' ContractDefinition = 'contract' Identifier
( 'is' InheritanceSpecifier (',' InheritanceSpecifier )* )?
'{' ContractPart* '}'
ContractPart = VariableDeclaration ';' | StructDefinition | ContractPart = VariableDeclaration ';' | StructDefinition |
FunctionDefinition | 'public:' | 'private:' FunctionDefinition | 'public:' | 'private:'
InheritanceSpecifier = Identifier ( '(' Expression ( ',' Expression )* ')' )?
StructDefinition = 'struct' Identifier '{' StructDefinition = 'struct' Identifier '{'
( VariableDeclaration (';' VariableDeclaration)* )? '} ( VariableDeclaration (';' VariableDeclaration)* )? '}

2
test/SolidityCompiler.cpp

@ -64,7 +64,7 @@ bytes compileContract(const string& _sourceCode)
if (ContractDefinition* contract = dynamic_cast<ContractDefinition*>(node.get())) if (ContractDefinition* contract = dynamic_cast<ContractDefinition*>(node.get()))
{ {
Compiler compiler; Compiler compiler;
compiler.compileContract(*contract, {}, map<ContractDefinition const*, bytes const*>{}); compiler.compileContract(*contract, map<ContractDefinition const*, bytes const*>{});
// debug // debug
//compiler.streamAssembly(cout); //compiler.streamAssembly(cout);

135
test/SolidityEndToEndTest.cpp

@ -1493,6 +1493,141 @@ BOOST_AUTO_TEST_CASE(value_for_constructor)
BOOST_REQUIRE(callContractFunction("getBalances()") == encodeArgs(12, 10)); BOOST_REQUIRE(callContractFunction("getBalances()") == encodeArgs(12, 10));
} }
BOOST_AUTO_TEST_CASE(virtual_function_calls)
{
char const* sourceCode = R"(
contract Base {
function f() returns (uint i) { return g(); }
function g() returns (uint i) { return 1; }
}
contract Derived is Base {
function g() returns (uint i) { return 2; }
}
)";
compileAndRun(sourceCode, 0, "Derived");
BOOST_CHECK(callContractFunction("g()") == encodeArgs(2));
BOOST_CHECK(callContractFunction("f()") == encodeArgs(2));
}
BOOST_AUTO_TEST_CASE(access_base_storage)
{
char const* sourceCode = R"(
contract Base {
uint dataBase;
function getViaBase() returns (uint i) { return dataBase; }
}
contract Derived is Base {
uint dataDerived;
function setData(uint base, uint derived) returns (bool r) {
dataBase = base;
dataDerived = derived;
return true;
}
function getViaDerived() returns (uint base, uint derived) {
base = dataBase;
derived = dataDerived;
}
}
)";
compileAndRun(sourceCode, 0, "Derived");
BOOST_CHECK(callContractFunction("setData(uint256,uint256)", 1, 2) == encodeArgs(true));
BOOST_CHECK(callContractFunction("getViaBase()") == encodeArgs(1));
BOOST_CHECK(callContractFunction("getViaDerived()") == encodeArgs(1, 2));
}
BOOST_AUTO_TEST_CASE(single_copy_with_multiple_inheritance)
{
char const* sourceCode = R"(
contract Base {
uint data;
function setData(uint i) { data = i; }
function getViaBase() returns (uint i) { return data; }
}
contract A is Base { function setViaA(uint i) { setData(i); } }
contract B is Base { function getViaB() returns (uint i) { return getViaBase(); } }
contract Derived is A, B, Base { }
)";
compileAndRun(sourceCode, 0, "Derived");
BOOST_CHECK(callContractFunction("getViaB()") == encodeArgs(0));
BOOST_CHECK(callContractFunction("setViaA(uint256)", 23) == encodeArgs());
BOOST_CHECK(callContractFunction("getViaB()") == encodeArgs(23));
}
BOOST_AUTO_TEST_CASE(explicit_base_cass)
{
char const* sourceCode = R"(
contract BaseBase { function g() returns (uint r) { return 1; } }
contract Base is BaseBase { function g() returns (uint r) { return 2; } }
contract Derived is Base {
function f() returns (uint r) { return BaseBase.g(); }
function g() returns (uint r) { return 3; }
}
)";
compileAndRun(sourceCode, 0, "Derived");
BOOST_CHECK(callContractFunction("g()") == encodeArgs(3));
BOOST_CHECK(callContractFunction("f()") == encodeArgs(1));
}
BOOST_AUTO_TEST_CASE(base_constructor_arguments)
{
char const* sourceCode = R"(
contract BaseBase {
uint m_a;
function BaseBase(uint a) {
m_a = a;
}
}
contract Base is BaseBase(7) {
function Base() {
m_a *= m_a;
}
}
contract Derived is Base() {
function getA() returns (uint r) { return m_a; }
}
)";
compileAndRun(sourceCode, 0, "Derived");
BOOST_CHECK(callContractFunction("getA()") == encodeArgs(7 * 7));
}
BOOST_AUTO_TEST_CASE(function_usage_in_constructor_arguments)
{
char const* sourceCode = R"(
contract BaseBase {
uint m_a;
function BaseBase(uint a) {
m_a = a;
}
function g() returns (uint r) { return 2; }
}
contract Base is BaseBase(BaseBase.g()) {
}
contract Derived is Base() {
function getA() returns (uint r) { return m_a; }
}
)";
compileAndRun(sourceCode, 0, "Derived");
BOOST_CHECK(callContractFunction("getA()") == encodeArgs(2));
}
BOOST_AUTO_TEST_CASE(constructor_argument_overriding)
{
char const* sourceCode = R"(
contract BaseBase {
uint m_a;
function BaseBase(uint a) {
m_a = a;
}
}
contract Base is BaseBase(2) { }
contract Derived is Base, BaseBase(3) {
function getA() returns (uint r) { return m_a; }
}
)";
compileAndRun(sourceCode, 0, "Derived");
BOOST_CHECK(callContractFunction("getA()") == encodeArgs(3));
}
BOOST_AUTO_TEST_SUITE_END() BOOST_AUTO_TEST_SUITE_END()
} }

123
test/SolidityNameAndTypeResolution.cpp

@ -357,7 +357,6 @@ BOOST_AUTO_TEST_CASE(function_canonical_signature_type_aliases)
} }
} }
BOOST_AUTO_TEST_CASE(hash_collision_in_interface) BOOST_AUTO_TEST_CASE(hash_collision_in_interface)
{ {
char const* text = "contract test {\n" char const* text = "contract test {\n"
@ -369,6 +368,128 @@ BOOST_AUTO_TEST_CASE(hash_collision_in_interface)
BOOST_CHECK_THROW(parseTextAndResolveNames(text), TypeError); BOOST_CHECK_THROW(parseTextAndResolveNames(text), TypeError);
} }
BOOST_AUTO_TEST_CASE(inheritance_basic)
{
char const* text = R"(
contract base { uint baseMember; struct BaseType { uint element; } }
contract derived is base {
BaseType data;
function f() { baseMember = 7; }
}
)";
BOOST_CHECK_NO_THROW(parseTextAndResolveNames(text));
}
BOOST_AUTO_TEST_CASE(inheritance_diamond_basic)
{
char const* text = R"(
contract root { function rootFunction() {} }
contract inter1 is root { function f() {} }
contract inter2 is root { function f() {} }
contract derived is inter1, inter2, root {
function g() { f(); rootFunction(); }
}
)";
BOOST_CHECK_NO_THROW(parseTextAndResolveNames(text));
}
BOOST_AUTO_TEST_CASE(cyclic_inheritance)
{
char const* text = R"(
contract A is B { }
contract B is A { }
)";
BOOST_CHECK_THROW(parseTextAndResolveNames(text), TypeError);
}
BOOST_AUTO_TEST_CASE(illegal_override_direct)
{
char const* text = R"(
contract B { function f() {} }
contract C is B { function f(uint i) {} }
)";
BOOST_CHECK_THROW(parseTextAndResolveNames(text), TypeError);
}
BOOST_AUTO_TEST_CASE(illegal_override_indirect)
{
char const* text = R"(
contract A { function f(uint a) {} }
contract B { function f() {} }
contract C is A, B { }
)";
BOOST_CHECK_THROW(parseTextAndResolveNames(text), TypeError);
}
BOOST_AUTO_TEST_CASE(complex_inheritance)
{
char const* text = R"(
contract A { function f() { uint8 x = C(0).g(); } }
contract B { function f() {} function g() returns (uint8 r) {} }
contract C is A, B { }
)";
BOOST_CHECK_NO_THROW(parseTextAndResolveNames(text));
}
BOOST_AUTO_TEST_CASE(constructor_visibility)
{
// The constructor of a base class should not be visible in the derived class
char const* text = R"(
contract A { function A() { } }
contract B is A { function f() { A x = A(0); } }
)";
BOOST_CHECK_NO_THROW(parseTextAndResolveNames(text));
}
BOOST_AUTO_TEST_CASE(overriding_constructor)
{
// It is fine to "override" constructor of a base class since it is invisible
char const* text = R"(
contract A { function A() { } }
contract B is A { function A() returns (uint8 r) {} }
)";
BOOST_CHECK_NO_THROW(parseTextAndResolveNames(text));
}
BOOST_AUTO_TEST_CASE(missing_base_constructor_arguments)
{
char const* text = R"(
contract A { function A(uint a) { } }
contract B is A { }
)";
BOOST_CHECK_THROW(parseTextAndResolveNames(text), TypeError);
}
BOOST_AUTO_TEST_CASE(base_constructor_arguments_override)
{
char const* text = R"(
contract A { function A(uint a) { } }
contract B is A { }
)";
BOOST_CHECK_THROW(parseTextAndResolveNames(text), TypeError);
}
BOOST_AUTO_TEST_CASE(implicit_derived_to_base_conversion)
{
char const* text = R"(
contract A { }
contract B is A {
function f() { A a = B(1); }
}
)";
BOOST_CHECK_NO_THROW(parseTextAndResolveNames(text));
}
BOOST_AUTO_TEST_CASE(implicit_base_to_derived_conversion)
{
char const* text = R"(
contract A { }
contract B is A {
function f() { B b = A(1); }
}
)";
BOOST_CHECK_THROW(parseTextAndResolveNames(text), TypeError);
}
BOOST_AUTO_TEST_SUITE_END() BOOST_AUTO_TEST_SUITE_END()
} }

45
test/SolidityParser.cpp

@ -495,6 +495,51 @@ BOOST_AUTO_TEST_CASE(multiple_contracts_and_imports)
BOOST_CHECK_NO_THROW(parseText(text)); BOOST_CHECK_NO_THROW(parseText(text));
} }
BOOST_AUTO_TEST_CASE(contract_inheritance)
{
char const* text = "contract base {\n"
" function fun() {\n"
" uint64(2);\n"
" }\n"
"}\n"
"contract derived is base {\n"
" function fun() {\n"
" uint64(2);\n"
" }\n"
"}\n";
BOOST_CHECK_NO_THROW(parseText(text));
}
BOOST_AUTO_TEST_CASE(contract_multiple_inheritance)
{
char const* text = "contract base {\n"
" function fun() {\n"
" uint64(2);\n"
" }\n"
"}\n"
"contract derived is base, nonExisting {\n"
" function fun() {\n"
" uint64(2);\n"
" }\n"
"}\n";
BOOST_CHECK_NO_THROW(parseText(text));
}
BOOST_AUTO_TEST_CASE(contract_multiple_inheritance_with_arguments)
{
char const* text = "contract base {\n"
" function fun() {\n"
" uint64(2);\n"
" }\n"
"}\n"
"contract derived is base(2), nonExisting(\"abc\", \"def\", base.fun()) {\n"
" function fun() {\n"
" uint64(2);\n"
" }\n"
"}\n";
BOOST_CHECK_NO_THROW(parseText(text));
}
BOOST_AUTO_TEST_SUITE_END() BOOST_AUTO_TEST_SUITE_END()
} }

Loading…
Cancel
Save