diff --git a/libsolidity/AST.cpp b/libsolidity/AST.cpp index 7b6335645..2cb738d30 100644 --- a/libsolidity/AST.cpp +++ b/libsolidity/AST.cpp @@ -43,6 +43,11 @@ TypeError ASTNode::createTypeError(string const& _description) const void ContractDefinition::checkTypeRequirements() { + for (ASTPointer const& base: getBaseContracts()) + base->checkTypeRequirements(); + + checkIllegalOverrides(); + FunctionDefinition const* constructor = getConstructor(); if (constructor && !constructor->getReturnParameters().empty()) BOOST_THROW_EXCEPTION(constructor->getReturnParameterList()->createTypeError( @@ -52,7 +57,6 @@ void ContractDefinition::checkTypeRequirements() function->checkTypeRequirements(); // check for hash collisions in function signatures - vector, FunctionDefinition const*>> exportedFunctionList = getInterfaceFunctionList(); set> hashes; for (auto const& hashAndFunction: getInterfaceFunctionList()) { @@ -83,17 +87,59 @@ FunctionDefinition const* ContractDefinition::getConstructor() const return nullptr; } -vector, FunctionDefinition const*>> ContractDefinition::getInterfaceFunctionList() const +void ContractDefinition::checkIllegalOverrides() const { - vector, FunctionDefinition const*>> exportedFunctions; - for (ASTPointer const& f: m_definedFunctions) - if (f->isPublic() && f->getName() != getName()) + map functions; + + // We search from derived to base, so the stored item causes the error. + for (ContractDefinition const* contract: getLinearizedBaseContracts()) + for (ASTPointer const& function: contract->getDefinedFunctions()) { - FixedHash<4> hash(dev::sha3(f->getCanonicalSignature())); - exportedFunctions.push_back(make_pair(hash, f.get())); + if (function->getName() == contract->getName()) + continue; // constructors can neither be overriden nor override anything + FunctionDefinition const*& override = functions[function->getName()]; + if (!override) + override = function.get(); + else if (override->isPublic() != function->isPublic() || + override->isDeclaredConst() != function->isDeclaredConst() || + FunctionType(*override) != FunctionType(*function)) + BOOST_THROW_EXCEPTION(override->createTypeError("Override changes extended function signature.")); } +} - return exportedFunctions; +vector, FunctionDefinition const*>> const& ContractDefinition::getInterfaceFunctionList() const +{ + if (!m_interfaceFunctionList) + { + set functionsSeen; + m_interfaceFunctionList.reset(new vector, FunctionDefinition const*>>()); + for (ContractDefinition const* contract: getLinearizedBaseContracts()) + for (ASTPointer const& f: contract->getDefinedFunctions()) + if (f->isPublic() && f->getName() != contract->getName() && + functionsSeen.count(f->getName()) == 0) + { + functionsSeen.insert(f->getName()); + FixedHash<4> hash(dev::sha3(f->getCanonicalSignature())); + m_interfaceFunctionList->push_back(make_pair(hash, f.get())); + } + } + return *m_interfaceFunctionList; +} + +void InheritanceSpecifier::checkTypeRequirements() +{ + m_baseName->checkTypeRequirements(); + for (ASTPointer const& argument: m_arguments) + argument->checkTypeRequirements(); + + ContractDefinition const* base = dynamic_cast(m_baseName->getReferencedDeclaration()); + solAssert(base, "Base contract not available."); + TypePointers parameterTypes = ContractType(*base).getConstructorType()->getParameterTypes(); + if (parameterTypes.size() != m_arguments.size()) + BOOST_THROW_EXCEPTION(createTypeError("Wrong argument count for constructor call.")); + for (size_t i = 0; i < m_arguments.size(); ++i) + if (!m_arguments[i]->getType()->isImplicitlyConvertibleTo(*parameterTypes[i])) + BOOST_THROW_EXCEPTION(createTypeError("Invalid type for argument in constructer call.")); } void StructDefinition::checkMemberTypes() const @@ -346,7 +392,8 @@ void MemberAccess::checkTypeRequirements() Type const& type = *m_expression->getType(); m_type = type.getMemberType(*m_memberName); if (!m_type) - BOOST_THROW_EXCEPTION(createTypeError("Member \"" + *m_memberName + "\" not found in " + type.toString())); + BOOST_THROW_EXCEPTION(createTypeError("Member \"" + *m_memberName + "\" not found or not " + "visible in " + type.toString())); //@todo later, this will not always be STORAGE m_lvalue = type.getCategory() == Type::Category::STRUCT ? LValueType::STORAGE : LValueType::NONE; } @@ -396,7 +443,7 @@ void Identifier::checkTypeRequirements() ContractDefinition const* contractDef = dynamic_cast(m_referencedDeclaration); if (contractDef) { - m_type = make_shared(make_shared(*contractDef)); + m_type = make_shared(make_shared(*contractDef), m_currentContract); return; } MagicVariableDeclaration const* magicVariable = dynamic_cast(m_referencedDeclaration); diff --git a/libsolidity/AST.h b/libsolidity/AST.h index 409aed443..8079348cf 100755 --- a/libsolidity/AST.h +++ b/libsolidity/AST.h @@ -158,10 +158,12 @@ public: ContractDefinition(Location const& _location, ASTPointer const& _name, ASTPointer const& _documentation, + std::vector> const& _baseContracts, std::vector> const& _definedStructs, std::vector> const& _stateVariables, std::vector> const& _definedFunctions): Declaration(_location, _name), + m_baseContracts(_baseContracts), m_definedStructs(_definedStructs), m_stateVariables(_stateVariables), m_definedFunctions(_definedFunctions), @@ -171,12 +173,13 @@ public: virtual void accept(ASTVisitor& _visitor) override; virtual void accept(ASTConstVisitor& _visitor) const override; + std::vector> const& getBaseContracts() const { return m_baseContracts; } std::vector> const& getDefinedStructs() const { return m_definedStructs; } std::vector> const& getStateVariables() const { return m_stateVariables; } std::vector> const& getDefinedFunctions() const { return m_definedFunctions; } - /// Checks that the constructor does not have a "returns" statement and calls - /// checkTypeRequirements on all its functions. + /// Checks that there are no illegal overrides, that the constructor does not have a "returns" + /// and calls checkTypeRequirements on all its functions. void checkTypeRequirements(); /// @return A shared pointer of an ASTString. @@ -187,16 +190,47 @@ public: /// as intended for use by the ABI. std::map, FunctionDefinition const*> getInterfaceFunctions() const; + /// List of all (direct and indirect) base contracts in order from derived to base, including + /// the contract itself. Available after name resolution + std::vector const& getLinearizedBaseContracts() const { return m_linearizedBaseContracts; } + void setLinearizedBaseContracts(std::vector const& _bases) { m_linearizedBaseContracts = _bases; } + /// Returns the constructor or nullptr if no constructor was specified FunctionDefinition const* getConstructor() const; private: - std::vector, FunctionDefinition const*>> getInterfaceFunctionList() const; + void checkIllegalOverrides() const; + + std::vector, FunctionDefinition const*>> const& getInterfaceFunctionList() const; + std::vector> m_baseContracts; std::vector> m_definedStructs; std::vector> m_stateVariables; std::vector> m_definedFunctions; ASTPointer m_documentation; + + std::vector m_linearizedBaseContracts; + mutable std::unique_ptr, FunctionDefinition const*>>> m_interfaceFunctionList; +}; + +class InheritanceSpecifier: public ASTNode +{ +public: + InheritanceSpecifier(Location const& _location, ASTPointer const& _baseName, + std::vector> _arguments): + ASTNode(_location), m_baseName(_baseName), m_arguments(_arguments) {} + + virtual void accept(ASTVisitor& _visitor) override; + virtual void accept(ASTConstVisitor& _visitor) const override; + + ASTPointer const& getName() const { return m_baseName; } + std::vector> const& getArguments() const { return m_arguments; } + + void checkTypeRequirements(); + +private: + ASTPointer m_baseName; + std::vector> m_arguments; }; class StructDefinition: public Declaration @@ -581,7 +615,7 @@ public: virtual void accept(ASTConstVisitor& _visitor) const override; virtual void checkTypeRequirements() override; - void setFunctionReturnParameters(ParameterList& _parameters) { m_returnParameters = &_parameters; } + void setFunctionReturnParameters(ParameterList const& _parameters) { m_returnParameters = &_parameters; } ParameterList const& getFunctionReturnParameters() const { solAssert(m_returnParameters, ""); @@ -593,7 +627,7 @@ private: ASTPointer m_expression; ///< value to return, optional /// Pointer to the parameter list of the function, filled by the @ref NameAndTypeResolver. - ParameterList* m_returnParameters; + ParameterList const* m_returnParameters; }; /** @@ -870,21 +904,30 @@ class Identifier: public PrimaryExpression { public: Identifier(Location const& _location, ASTPointer const& _name): - PrimaryExpression(_location), m_name(_name), m_referencedDeclaration(nullptr) {} + PrimaryExpression(_location), m_name(_name) {} virtual void accept(ASTVisitor& _visitor) override; virtual void accept(ASTConstVisitor& _visitor) const override; virtual void checkTypeRequirements() override; ASTString const& getName() const { return *m_name; } - void setReferencedDeclaration(Declaration const& _referencedDeclaration) { m_referencedDeclaration = &_referencedDeclaration; } + void setReferencedDeclaration(Declaration const& _referencedDeclaration, + ContractDefinition const* _currentContract = nullptr) + { + m_referencedDeclaration = &_referencedDeclaration; + m_currentContract = _currentContract; + } Declaration const* getReferencedDeclaration() const { return m_referencedDeclaration; } + ContractDefinition const* getCurrentContract() const { return m_currentContract; } private: ASTPointer m_name; /// Declaration the name refers to. - Declaration const* m_referencedDeclaration; + Declaration const* m_referencedDeclaration = nullptr; + /// Stores a reference to the current contract. This is needed because types of base contracts + /// change depending on the context. + ContractDefinition const* m_currentContract = nullptr; }; /** diff --git a/libsolidity/ASTForward.h b/libsolidity/ASTForward.h index c960fc8f0..da0a88122 100644 --- a/libsolidity/ASTForward.h +++ b/libsolidity/ASTForward.h @@ -38,6 +38,7 @@ class SourceUnit; class ImportDirective; class Declaration; class ContractDefinition; +class InheritanceSpecifier; class StructDefinition; class ParameterList; class FunctionDefinition; diff --git a/libsolidity/AST_accept.h b/libsolidity/AST_accept.h index 7f3db85a1..b77cfe1c6 100644 --- a/libsolidity/AST_accept.h +++ b/libsolidity/AST_accept.h @@ -61,6 +61,7 @@ void ContractDefinition::accept(ASTVisitor& _visitor) { if (_visitor.visit(*this)) { + listAccept(m_baseContracts, _visitor); listAccept(m_definedStructs, _visitor); listAccept(m_stateVariables, _visitor); listAccept(m_definedFunctions, _visitor); @@ -72,6 +73,7 @@ void ContractDefinition::accept(ASTConstVisitor& _visitor) const { if (_visitor.visit(*this)) { + listAccept(m_baseContracts, _visitor); listAccept(m_definedStructs, _visitor); listAccept(m_stateVariables, _visitor); listAccept(m_definedFunctions, _visitor); @@ -79,6 +81,26 @@ void ContractDefinition::accept(ASTConstVisitor& _visitor) const _visitor.endVisit(*this); } +void InheritanceSpecifier::accept(ASTVisitor& _visitor) +{ + if (_visitor.visit(*this)) + { + m_baseName->accept(_visitor); + listAccept(m_arguments, _visitor); + } + _visitor.endVisit(*this); +} + +void InheritanceSpecifier::accept(ASTConstVisitor& _visitor) const +{ + if (_visitor.visit(*this)) + { + m_baseName->accept(_visitor); + listAccept(m_arguments, _visitor); + } + _visitor.endVisit(*this); +} + void StructDefinition::accept(ASTVisitor& _visitor) { if (_visitor.visit(*this)) diff --git a/libsolidity/CallGraph.cpp b/libsolidity/CallGraph.cpp index b30afb612..88d874f3b 100644 --- a/libsolidity/CallGraph.cpp +++ b/libsolidity/CallGraph.cpp @@ -31,13 +31,9 @@ namespace dev namespace solidity { -void CallGraph::addFunction(FunctionDefinition const& _function) +void CallGraph::addNode(ASTNode const& _node) { - if (!m_functionsSeen.count(&_function)) - { - m_functionsSeen.insert(&_function); - m_workQueue.push(&_function); - } + _node.accept(*this); } set const& CallGraph::getCalls() @@ -63,5 +59,41 @@ bool CallGraph::visit(Identifier const& _identifier) return true; } +bool CallGraph::visit(FunctionDefinition const& _function) +{ + addFunction(_function); + return true; +} + +bool CallGraph::visit(MemberAccess const& _memberAccess) +{ + // used for "BaseContract.baseContractFunction" + if (_memberAccess.getExpression().getType()->getCategory() == Type::Category::TYPE) + { + TypeType const& type = dynamic_cast(*_memberAccess.getExpression().getType()); + if (type.getMembers().getMemberType(_memberAccess.getMemberName())) + { + ContractDefinition const& contract = dynamic_cast(*type.getActualType()) + .getContractDefinition(); + for (ASTPointer const& function: contract.getDefinedFunctions()) + if (function->getName() == _memberAccess.getMemberName()) + { + addFunction(*function); + return true; + } + } + } + return true; +} + +void CallGraph::addFunction(FunctionDefinition const& _function) +{ + if (!m_functionsSeen.count(&_function)) + { + m_functionsSeen.insert(&_function); + m_workQueue.push(&_function); + } +} + } } diff --git a/libsolidity/CallGraph.h b/libsolidity/CallGraph.h index f7af64bff..e3558fc25 100644 --- a/libsolidity/CallGraph.h +++ b/libsolidity/CallGraph.h @@ -38,14 +38,17 @@ namespace solidity class CallGraph: private ASTConstVisitor { public: - void addFunction(FunctionDefinition const& _function); + void addNode(ASTNode const& _node); void computeCallGraph(); std::set const& getCalls(); private: - void addFunctionToQueue(FunctionDefinition const& _function); + virtual bool visit(FunctionDefinition const& _function) override; virtual bool visit(Identifier const& _identifier) override; + virtual bool visit(MemberAccess const& _memberAccess) override; + + void addFunction(FunctionDefinition const& _function); std::set m_functionsSeen; std::queue m_workQueue; diff --git a/libsolidity/Compiler.cpp b/libsolidity/Compiler.cpp index bd6571b9a..36316b9ae 100644 --- a/libsolidity/Compiler.cpp +++ b/libsolidity/Compiler.cpp @@ -21,6 +21,7 @@ */ #include +#include #include #include #include @@ -34,48 +35,84 @@ using namespace std; namespace dev { namespace solidity { -void Compiler::compileContract(ContractDefinition const& _contract, vector const& _magicGlobals, +void Compiler::compileContract(ContractDefinition const& _contract, map const& _contracts) { m_context = CompilerContext(); // clear it just in case - initializeContext(_contract, _magicGlobals, _contracts); + initializeContext(_contract, _contracts); - for (ASTPointer const& function: _contract.getDefinedFunctions()) - if (function->getName() != _contract.getName()) // don't add the constructor here - m_context.addFunction(*function); + for (ContractDefinition const* contract: _contract.getLinearizedBaseContracts()) + for (ASTPointer const& function: contract->getDefinedFunctions()) + if (function->getName() != contract->getName()) // don't add the constructor here + m_context.addFunction(*function); appendFunctionSelector(_contract); - for (ASTPointer const& function: _contract.getDefinedFunctions()) - if (function->getName() != _contract.getName()) // don't add the constructor here - function->accept(*this); + for (ContractDefinition const* contract: _contract.getLinearizedBaseContracts()) + for (ASTPointer const& function: contract->getDefinedFunctions()) + if (function->getName() != contract->getName()) // don't add the constructor here + function->accept(*this); // Swap the runtime context with the creation-time context swap(m_context, m_runtimeContext); - initializeContext(_contract, _magicGlobals, _contracts); + initializeContext(_contract, _contracts); packIntoContractCreator(_contract, m_runtimeContext); } -void Compiler::initializeContext(ContractDefinition const& _contract, vector const& _magicGlobals, +void Compiler::initializeContext(ContractDefinition const& _contract, map const& _contracts) { m_context.setCompiledContracts(_contracts); - for (MagicVariableDeclaration const* variable: _magicGlobals) - m_context.addMagicGlobal(*variable); registerStateVariables(_contract); } void Compiler::packIntoContractCreator(ContractDefinition const& _contract, CompilerContext const& _runtimeContext) { + // arguments for base constructors, filled in derived-to-base order + map> const*> baseArguments; set neededFunctions; - FunctionDefinition const* constructor = _contract.getConstructor(); - if (constructor) - neededFunctions = getFunctionsNeededByConstructor(*constructor); + set nodesUsedInConstructors; + + // Determine the arguments that are used for the base constructors and also which functions + // are needed at compile time. + std::vector const& bases = _contract.getLinearizedBaseContracts(); + for (ContractDefinition const* contract: bases) + { + if (FunctionDefinition const* constructor = contract->getConstructor()) + nodesUsedInConstructors.insert(constructor); + for (ASTPointer const& base: contract->getBaseContracts()) + { + ContractDefinition const* baseContract = dynamic_cast( + base->getName()->getReferencedDeclaration()); + solAssert(baseContract, ""); + if (baseArguments.count(baseContract) == 0) + { + baseArguments[baseContract] = &base->getArguments(); + for (ASTPointer const& arg: base->getArguments()) + nodesUsedInConstructors.insert(arg.get()); + } + } + } + + //@TODO add virtual functions + neededFunctions = getFunctionsCalled(nodesUsedInConstructors); for (FunctionDefinition const* fun: neededFunctions) m_context.addFunction(*fun); - if (constructor) - appendConstructorCall(*constructor); + // Call constructors in base-to-derived order. + // The Constructor for the most derived contract is called later. + for (unsigned i = 1; i < bases.size(); i++) + { + ContractDefinition const* base = bases[bases.size() - i]; + solAssert(base, ""); + FunctionDefinition const* baseConstructor = base->getConstructor(); + if (!baseConstructor) + continue; + solAssert(baseArguments[base], ""); + appendBaseConstructorCall(*baseConstructor, *baseArguments[base]); + } + if (_contract.getConstructor()) + appendConstructorCall(*_contract.getConstructor()); eth::AssemblyItem sub = m_context.addSubroutine(_runtimeContext.getAssembly()); // stack contains sub size @@ -88,6 +125,21 @@ void Compiler::packIntoContractCreator(ContractDefinition const& _contract, Comp fun->accept(*this); } +void Compiler::appendBaseConstructorCall(FunctionDefinition const& _constructor, + vector> const& _arguments) +{ + FunctionType constructorType(_constructor); + eth::AssemblyItem returnLabel = m_context.pushNewTag(); + for (unsigned i = 0; i < _arguments.size(); ++i) + { + compileExpression(*_arguments[i]); + ExpressionCompiler::appendTypeConversion(m_context, *_arguments[i]->getType(), + *constructorType.getParameterTypes()[i]); + } + m_context.appendJumpTo(m_context.getFunctionEntryLabel(_constructor)); + m_context << returnLabel; +} + void Compiler::appendConstructorCall(FunctionDefinition const& _constructor) { eth::AssemblyItem returnTag = m_context.pushNewTag(); @@ -107,11 +159,12 @@ void Compiler::appendConstructorCall(FunctionDefinition const& _constructor) m_context << returnTag; } -set Compiler::getFunctionsNeededByConstructor(FunctionDefinition const& _constructor) +set Compiler::getFunctionsCalled(set const& _nodes) { + // TODO this does not add virtual functions CallGraph callgraph; - callgraph.addFunction(_constructor); - callgraph.computeCallGraph(); + for (ASTNode const* node: _nodes) + callgraph.addNode(*node); return callgraph.getCalls(); } @@ -193,9 +246,9 @@ void Compiler::appendReturnValuePacker(FunctionDefinition const& _function) void Compiler::registerStateVariables(ContractDefinition const& _contract) { - //@todo sort them? - for (ASTPointer const& variable: _contract.getStateVariables()) - m_context.addStateVariable(*variable); + for (ContractDefinition const* contract: boost::adaptors::reverse(_contract.getLinearizedBaseContracts())) + for (ASTPointer const& variable: contract->getStateVariables()) + m_context.addStateVariable(*variable); } bool Compiler::visit(FunctionDefinition const& _function) diff --git a/libsolidity/Compiler.h b/libsolidity/Compiler.h index c229a7a8a..ea05f38ee 100644 --- a/libsolidity/Compiler.h +++ b/libsolidity/Compiler.h @@ -32,23 +32,24 @@ class Compiler: private ASTConstVisitor public: explicit Compiler(bool _optimize = false): m_optimize(_optimize), m_context(), m_returnTag(m_context.newTag()) {} - void compileContract(ContractDefinition const& _contract, std::vector const& _magicGlobals, + void compileContract(ContractDefinition const& _contract, std::map const& _contracts); bytes getAssembledBytecode() { return m_context.getAssembledBytecode(m_optimize); } bytes getRuntimeBytecode() { return m_runtimeContext.getAssembledBytecode(m_optimize);} void streamAssembly(std::ostream& _stream) const { m_context.streamAssembly(_stream); } private: - /// Registers the global objects and the non-function objects inside the contract with the context. - void initializeContext(ContractDefinition const& _contract, std::vector const& _magicGlobals, + /// Registers the non-function objects inside the contract with the context. + void initializeContext(ContractDefinition const& _contract, std::map const& _contracts); /// Adds the code that is run at creation time. Should be run after exchanging the run-time context - /// with a new and initialized context. - /// adds the constructor code. + /// with a new and initialized context. Adds the constructor code. void packIntoContractCreator(ContractDefinition const& _contract, CompilerContext const& _runtimeContext); + void appendBaseConstructorCall(FunctionDefinition const& _constructor, + std::vector> const& _arguments); void appendConstructorCall(FunctionDefinition const& _constructor); - /// Recursively searches the call graph and returns all functions needed by the constructor (including itself). - std::set getFunctionsNeededByConstructor(FunctionDefinition const& _constructor); + /// Recursively searches the call graph and returns all functions referenced inside _nodes. + std::set getFunctionsCalled(std::set const& _nodes); void appendFunctionSelector(ContractDefinition const& _contract); /// Creates code that unpacks the arguments for the given function, from memory if /// @a _fromMemory is true, otherwise from call data. @returns the size of the data in bytes. diff --git a/libsolidity/CompilerContext.cpp b/libsolidity/CompilerContext.cpp index 29e98eabf..27ec3efd5 100644 --- a/libsolidity/CompilerContext.cpp +++ b/libsolidity/CompilerContext.cpp @@ -61,7 +61,9 @@ void CompilerContext::addAndInitializeVariable(VariableDeclaration const& _decla void CompilerContext::addFunction(FunctionDefinition const& _function) { - m_functionEntryLabels.insert(std::make_pair(&_function, m_asm.newTag())); + eth::AssemblyItem tag(m_asm.newTag()); + m_functionEntryLabels.insert(make_pair(&_function, tag)); + m_virtualFunctionEntryLabels.insert(make_pair(_function.getName(), tag)); } bytes const& CompilerContext::getCompiledContract(const ContractDefinition& _contract) const @@ -83,6 +85,13 @@ eth::AssemblyItem CompilerContext::getFunctionEntryLabel(FunctionDefinition cons return res->second.tag(); } +eth::AssemblyItem CompilerContext::getVirtualFunctionEntryLabel(FunctionDefinition const& _function) const +{ + auto res = m_virtualFunctionEntryLabels.find(_function.getName()); + solAssert(res != m_virtualFunctionEntryLabels.end(), "Function entry label not found."); + return res->second.tag(); +} + unsigned CompilerContext::getBaseStackOffsetOfVariable(Declaration const& _declaration) const { auto res = m_localVariables.find(&_declaration); diff --git a/libsolidity/CompilerContext.h b/libsolidity/CompilerContext.h index cf505d654..cde992d58 100644 --- a/libsolidity/CompilerContext.h +++ b/libsolidity/CompilerContext.h @@ -57,6 +57,8 @@ public: bool isStateVariable(Declaration const* _declaration) const { return m_stateVariables.count(_declaration) != 0; } eth::AssemblyItem getFunctionEntryLabel(FunctionDefinition const& _function) const; + /// @returns the entry label of the given function and takes overrides into account. + eth::AssemblyItem getVirtualFunctionEntryLabel(FunctionDefinition const& _function) const; /// Returns the distance of the given local variable from the top of the local variable stack. unsigned getBaseStackOffsetOfVariable(Declaration const& _declaration) const; /// If supplied by a value returned by @ref getBaseStackOffsetOfVariable(variable), returns @@ -116,6 +118,8 @@ private: unsigned m_localVariablesSize; /// Labels pointing to the entry points of funcitons. std::map m_functionEntryLabels; + /// Labels pointing to the entry points of function overrides. + std::map m_virtualFunctionEntryLabels; }; } diff --git a/libsolidity/CompilerStack.cpp b/libsolidity/CompilerStack.cpp index 5532d74bc..790eb983a 100644 --- a/libsolidity/CompilerStack.cpp +++ b/libsolidity/CompilerStack.cpp @@ -113,10 +113,8 @@ void CompilerStack::compile(bool _optimize) for (ASTPointer const& node: source->ast->getNodes()) if (ContractDefinition* contract = dynamic_cast(node.get())) { - m_globalContext->setCurrentContract(*contract); shared_ptr compiler = make_shared(_optimize); - compiler->compileContract(*contract, m_globalContext->getMagicVariables(), - contractBytecode); + compiler->compileContract(*contract, contractBytecode); Contract& compiledContract = m_contracts[contract->getName()]; compiledContract.bytecode = compiler->getAssembledBytecode(); compiledContract.runtimeBytecode = compiler->getRuntimeBytecode(); diff --git a/libsolidity/DeclarationContainer.h b/libsolidity/DeclarationContainer.h index c0a0b42c7..e4b793259 100644 --- a/libsolidity/DeclarationContainer.h +++ b/libsolidity/DeclarationContainer.h @@ -42,11 +42,12 @@ public: explicit DeclarationContainer(Declaration const* _enclosingDeclaration = nullptr, DeclarationContainer const* _enclosingContainer = nullptr): m_enclosingDeclaration(_enclosingDeclaration), m_enclosingContainer(_enclosingContainer) {} - /// Registers the declaration in the scope unless its name is already declared. Returns true iff - /// it was not yet declared. + /// Registers the declaration in the scope unless its name is already declared. + /// @returns true iff it was not yet declared. bool registerDeclaration(Declaration const& _declaration, bool _update = false); Declaration const* resolveName(ASTString const& _name, bool _recursive = false) const; Declaration const* getEnclosingDeclaration() const { return m_enclosingDeclaration; } + std::map const& getDeclarations() const { return m_declarations; } private: Declaration const* m_enclosingDeclaration; diff --git a/libsolidity/ExpressionCompiler.cpp b/libsolidity/ExpressionCompiler.cpp index df90d0d9d..60c5c4ded 100644 --- a/libsolidity/ExpressionCompiler.cpp +++ b/libsolidity/ExpressionCompiler.cpp @@ -419,6 +419,22 @@ void ExpressionCompiler::endVisit(MemberAccess const& _memberAccess) m_currentLValue.retrieveValueIfLValueNotRequested(_memberAccess); break; } + case Type::Category::TYPE: + { + TypeType const& type = dynamic_cast(*_memberAccess.getExpression().getType()); + if (type.getMembers().getMemberType(member)) + { + ContractDefinition const& contract = dynamic_cast(*type.getActualType()) + .getContractDefinition(); + for (ASTPointer const& function: contract.getDefinedFunctions()) + if (function->getName() == member) + { + m_context << m_context.getFunctionEntryLabel(*function).pushTag(); + return; + } + } + BOOST_THROW_EXCEPTION(InternalCompilerError() << errinfo_comment("Invalid member access to " + type.toString())); + } default: BOOST_THROW_EXCEPTION(InternalCompilerError() << errinfo_comment("Member access to unknown type.")); } @@ -449,20 +465,22 @@ void ExpressionCompiler::endVisit(Identifier const& _identifier) { if (magicVar->getType()->getCategory() == Type::Category::CONTRACT) // must be "this" m_context << eth::Instruction::ADDRESS; - return; - } - if (FunctionDefinition const* functionDef = dynamic_cast(declaration)) - { - m_context << m_context.getFunctionEntryLabel(*functionDef).pushTag(); - return; } - if (dynamic_cast(declaration)) + else if (FunctionDefinition const* functionDef = dynamic_cast(declaration)) + m_context << m_context.getVirtualFunctionEntryLabel(*functionDef).pushTag(); + else if (dynamic_cast(declaration)) { m_currentLValue.fromIdentifier(_identifier, *declaration); m_currentLValue.retrieveValueIfLValueNotRequested(_identifier); - return; } - BOOST_THROW_EXCEPTION(InternalCompilerError() << errinfo_comment("Identifier type not expected in expression context.")); + else if (dynamic_cast(declaration)) + { + // no-op + } + else + { + BOOST_THROW_EXCEPTION(InternalCompilerError() << errinfo_comment("Identifier type not expected in expression context.")); + } } void ExpressionCompiler::endVisit(Literal const& _literal) diff --git a/libsolidity/GlobalContext.cpp b/libsolidity/GlobalContext.cpp index 92ca9548a..c7eea92dc 100644 --- a/libsolidity/GlobalContext.cpp +++ b/libsolidity/GlobalContext.cpp @@ -68,7 +68,7 @@ void GlobalContext::setCurrentContract(ContractDefinition const& _contract) vector GlobalContext::getDeclarations() const { vector declarations; - declarations.reserve(m_magicVariables.size() + 1); + declarations.reserve(m_magicVariables.size()); for (ASTPointer const& variable: m_magicVariables) declarations.push_back(variable.get()); return declarations; @@ -83,15 +83,5 @@ MagicVariableDeclaration const* GlobalContext::getCurrentThis() const } -vector GlobalContext::getMagicVariables() const -{ - vector declarations; - declarations.reserve(m_magicVariables.size() + 1); - for (ASTPointer const& variable: m_magicVariables) - declarations.push_back(variable.get()); - declarations.push_back(getCurrentThis()); - return declarations; -} - } } diff --git a/libsolidity/GlobalContext.h b/libsolidity/GlobalContext.h index c6e35f504..dfdc66623 100644 --- a/libsolidity/GlobalContext.h +++ b/libsolidity/GlobalContext.h @@ -49,8 +49,6 @@ public: void setCurrentContract(ContractDefinition const& _contract); MagicVariableDeclaration const* getCurrentThis() const; - /// @returns all magic variables. - std::vector getMagicVariables() const; /// @returns a vector of all implicit global declarations excluding "this". std::vector getDeclarations() const; diff --git a/libsolidity/NameAndTypeResolver.cpp b/libsolidity/NameAndTypeResolver.cpp index 3774537d1..ba5ca1345 100644 --- a/libsolidity/NameAndTypeResolver.cpp +++ b/libsolidity/NameAndTypeResolver.cpp @@ -31,7 +31,6 @@ namespace dev namespace solidity { - NameAndTypeResolver::NameAndTypeResolver(std::vector const& _globals) { for (Declaration const* declaration: _globals) @@ -46,18 +45,27 @@ void NameAndTypeResolver::registerDeclarations(SourceUnit& _sourceUnit) void NameAndTypeResolver::resolveNamesAndTypes(ContractDefinition& _contract) { + m_currentScope = &m_scopes[nullptr]; + + for (ASTPointer const& baseContract: _contract.getBaseContracts()) + ReferencesResolver resolver(*baseContract, *this, &_contract, nullptr); + m_currentScope = &m_scopes[&_contract]; + + linearizeBaseContracts(_contract); + for (ContractDefinition const* base: _contract.getLinearizedBaseContracts()) + importInheritedScope(*base); + for (ASTPointer const& structDef: _contract.getDefinedStructs()) - ReferencesResolver resolver(*structDef, *this, nullptr); + ReferencesResolver resolver(*structDef, *this, &_contract, nullptr); for (ASTPointer const& variable: _contract.getStateVariables()) - ReferencesResolver resolver(*variable, *this, nullptr); + ReferencesResolver resolver(*variable, *this, &_contract, nullptr); for (ASTPointer const& function: _contract.getDefinedFunctions()) { m_currentScope = &m_scopes[function.get()]; - ReferencesResolver referencesResolver(*function, *this, + ReferencesResolver referencesResolver(*function, *this, &_contract, function->getReturnParameterList().get()); } - m_currentScope = &m_scopes[nullptr]; } void NameAndTypeResolver::checkTypeRequirements(ContractDefinition& _contract) @@ -86,6 +94,96 @@ Declaration const* NameAndTypeResolver::getNameFromCurrentScope(ASTString const& return m_currentScope->resolveName(_name, _recursive); } +void NameAndTypeResolver::importInheritedScope(ContractDefinition const& _base) +{ + auto iterator = m_scopes.find(&_base); + solAssert(iterator != end(m_scopes), ""); + for (auto const& nameAndDeclaration: iterator->second.getDeclarations()) + { + Declaration const* declaration = nameAndDeclaration.second; + // Import if it was declared in the base and is not the constructor + if (declaration->getScope() == &_base && declaration->getName() != _base.getName()) + m_currentScope->registerDeclaration(*declaration); + } +} + +void NameAndTypeResolver::linearizeBaseContracts(ContractDefinition& _contract) const +{ + // order in the lists is from derived to base + // list of lists to linearize, the last element is the list of direct bases + list> input(1, {&_contract}); + for (ASTPointer const& baseSpecifier: _contract.getBaseContracts()) + { + ASTPointer baseName = baseSpecifier->getName(); + ContractDefinition const* base = dynamic_cast( + baseName->getReferencedDeclaration()); + if (!base) + BOOST_THROW_EXCEPTION(baseName->createTypeError("Contract expected.")); + // "push_back" has the effect that bases mentioned earlier can overwrite members of bases + // mentioned later + input.back().push_back(base); + vector const& basesBases = base->getLinearizedBaseContracts(); + if (basesBases.empty()) + BOOST_THROW_EXCEPTION(baseName->createTypeError("Definition of base has to precede definition of derived contract")); + input.push_front(list(basesBases.begin(), basesBases.end())); + } + vector result = cThreeMerge(input); + if (result.empty()) + BOOST_THROW_EXCEPTION(_contract.createTypeError("Linearization of inheritance graph impossible")); + _contract.setLinearizedBaseContracts(result); +} + +template +vector<_T const*> NameAndTypeResolver::cThreeMerge(list>& _toMerge) +{ + // returns true iff _candidate appears only as last element of the lists + auto appearsOnlyAtHead = [&](_T const* _candidate) -> bool + { + for (list<_T const*> const& bases: _toMerge) + { + solAssert(!bases.empty(), ""); + if (find(++bases.begin(), bases.end(), _candidate) != bases.end()) + return false; + } + return true; + }; + // returns the next candidate to append to the linearized list or nullptr on failure + auto nextCandidate = [&]() -> _T const* + { + for (list<_T const*> const& bases: _toMerge) + { + solAssert(!bases.empty(), ""); + if (appearsOnlyAtHead(bases.front())) + return bases.front(); + } + return nullptr; + }; + // removes the given contract from all lists + auto removeCandidate = [&](_T const* _candidate) + { + for (auto it = _toMerge.begin(); it != _toMerge.end();) + { + it->remove(_candidate); + if (it->empty()) + it = _toMerge.erase(it); + else + ++it; + } + }; + + _toMerge.remove_if([](list<_T const*> const& _bases) { return _bases.empty(); }); + vector<_T const*> result; + while (!_toMerge.empty()) + { + _T const* candidate = nextCandidate(); + if (!candidate) + return vector<_T const*>(); + result.push_back(candidate); + removeCandidate(candidate); + } + return result; +} + DeclarationRegistrationHelper::DeclarationRegistrationHelper(map& _scopes, ASTNode& _astRoot): m_scopes(_scopes), m_currentScope(nullptr) @@ -169,8 +267,10 @@ void DeclarationRegistrationHelper::registerDeclaration(Declaration& _declaratio } ReferencesResolver::ReferencesResolver(ASTNode& _root, NameAndTypeResolver& _resolver, - ParameterList* _returnParameters, bool _allowLazyTypes): - m_resolver(_resolver), m_returnParameters(_returnParameters), m_allowLazyTypes(_allowLazyTypes) + ContractDefinition const* _currentContract, + ParameterList const* _returnParameters, bool _allowLazyTypes): + m_resolver(_resolver), m_currentContract(_currentContract), + m_returnParameters(_returnParameters), m_allowLazyTypes(_allowLazyTypes) { _root.accept(*this); } @@ -218,7 +318,7 @@ bool ReferencesResolver::visit(Identifier& _identifier) if (!declaration) BOOST_THROW_EXCEPTION(DeclarationError() << errinfo_sourceLocation(_identifier.getLocation()) << errinfo_comment("Undeclared identifier.")); - _identifier.setReferencedDeclaration(*declaration); + _identifier.setReferencedDeclaration(*declaration, m_currentContract); return false; } diff --git a/libsolidity/NameAndTypeResolver.h b/libsolidity/NameAndTypeResolver.h index 1032a87cf..f97c7ae56 100644 --- a/libsolidity/NameAndTypeResolver.h +++ b/libsolidity/NameAndTypeResolver.h @@ -23,6 +23,7 @@ #pragma once #include +#include #include #include @@ -64,6 +65,17 @@ public: private: void reset(); + /// Imports all members declared directly in the given contract (i.e. does not import inherited + /// members) into the current scope if they are not present already. + void importInheritedScope(ContractDefinition const& _base); + + /// Computes "C3-Linearization" of base contracts and stores it inside the contract. + void linearizeBaseContracts(ContractDefinition& _contract) const; + /// Computes the C3-merge of the given list of lists of bases. + /// @returns the linearized vector or an empty vector if linearization is not possible. + template + static std::vector<_T const*> cThreeMerge(std::list>& _toMerge); + /// Maps nodes declaring a scope to scopes, i.e. ContractDefinition and FunctionDeclaration, /// where nullptr denotes the global scope. Note that structs are not scope since they do /// not contain code. @@ -108,7 +120,9 @@ class ReferencesResolver: private ASTVisitor { public: ReferencesResolver(ASTNode& _root, NameAndTypeResolver& _resolver, - ParameterList* _returnParameters, bool _allowLazyTypes = true); + ContractDefinition const* _currentContract, + ParameterList const* _returnParameters, + bool _allowLazyTypes = true); private: virtual void endVisit(VariableDeclaration& _variable) override; @@ -118,7 +132,8 @@ private: virtual bool visit(Return& _return) override; NameAndTypeResolver& m_resolver; - ParameterList* m_returnParameters; + ContractDefinition const* m_currentContract; + ParameterList const* m_returnParameters; bool m_allowLazyTypes; }; diff --git a/libsolidity/Parser.cpp b/libsolidity/Parser.cpp index ebff3ba40..c0ca1abb2 100644 --- a/libsolidity/Parser.cpp +++ b/libsolidity/Parser.cpp @@ -117,10 +117,18 @@ ASTPointer Parser::parseContractDefinition() docstring = make_shared(m_scanner->getCurrentCommentLiteral()); expectToken(Token::CONTRACT); ASTPointer name = expectIdentifierToken(); - expectToken(Token::LBRACE); + vector> baseContracts; vector> structs; vector> stateVariables; vector> functions; + if (m_scanner->getCurrentToken() == Token::IS) + do + { + m_scanner->next(); + baseContracts.push_back(parseInheritanceSpecifier()); + } + while (m_scanner->getCurrentToken() == Token::COMMA); + expectToken(Token::LBRACE); bool visibilityIsPublic = true; while (true) { @@ -149,7 +157,25 @@ ASTPointer Parser::parseContractDefinition() } nodeFactory.markEndPosition(); expectToken(Token::RBRACE); - return nodeFactory.createNode(name, docstring, structs, stateVariables, functions); + return nodeFactory.createNode(name, docstring, baseContracts, structs, + stateVariables, functions); +} + +ASTPointer Parser::parseInheritanceSpecifier() +{ + ASTNodeFactory nodeFactory(*this); + ASTPointer name = ASTNodeFactory(*this).createNode(expectIdentifierToken()); + vector> arguments; + if (m_scanner->getCurrentToken() == Token::LPAREN) + { + m_scanner->next(); + arguments = parseFunctionCallArguments(); + nodeFactory.markEndPosition(); + expectToken(Token::RPAREN); + } + else + nodeFactory.setEndPositionFromNode(name); + return nodeFactory.createNode(name, arguments); } ASTPointer Parser::parseFunctionDefinition(bool _isPublic) diff --git a/libsolidity/Parser.h b/libsolidity/Parser.h index bf3a6beac..1b7a980ff 100644 --- a/libsolidity/Parser.h +++ b/libsolidity/Parser.h @@ -49,6 +49,7 @@ private: ///@name Parsing functions for the AST nodes ASTPointer parseImportDirective(); ASTPointer parseContractDefinition(); + ASTPointer parseInheritanceSpecifier(); ASTPointer parseFunctionDefinition(bool _isPublic); ASTPointer parseStructDefinition(); ASTPointer parseVariableDeclaration(bool _allowVar); diff --git a/libsolidity/Token.h b/libsolidity/Token.h index 9fb86e7f4..552e9a75e 100644 --- a/libsolidity/Token.h +++ b/libsolidity/Token.h @@ -153,7 +153,7 @@ namespace solidity K(DEFAULT, "default", 0) \ K(DO, "do", 0) \ K(ELSE, "else", 0) \ - K(EXTENDS, "extends", 0) \ + K(IS, "is", 0) \ K(FOR, "for", 0) \ K(FUNCTION, "function", 0) \ K(IF, "if", 0) \ diff --git a/libsolidity/Types.cpp b/libsolidity/Types.cpp index a99e4853f..c6d8b62fc 100644 --- a/libsolidity/Types.cpp +++ b/libsolidity/Types.cpp @@ -431,12 +431,19 @@ bool ContractType::isImplicitlyConvertibleTo(Type const& _convertTo) const return true; if (_convertTo.getCategory() == Category::INTEGER) return dynamic_cast(_convertTo).isAddress(); + if (_convertTo.getCategory() == Category::CONTRACT) + { + auto const& bases = getContractDefinition().getLinearizedBaseContracts(); + return find(bases.begin(), bases.end(), + &dynamic_cast(_convertTo).getContractDefinition()) != bases.end(); + } return false; } bool ContractType::isExplicitlyConvertibleTo(Type const& _convertTo) const { - return isImplicitlyConvertibleTo(_convertTo) || _convertTo.getCategory() == Category::INTEGER; + return isImplicitlyConvertibleTo(_convertTo) || _convertTo.getCategory() == Category::INTEGER || + _convertTo.getCategory() == Category::CONTRACT; } TypePointer ContractType::unaryOperatorResult(Token::Value _operator) const @@ -695,6 +702,29 @@ bool TypeType::operator==(Type const& _other) const return *getActualType() == *other.getActualType(); } +MemberList const& TypeType::getMembers() const +{ + // We need to lazy-initialize it because of recursive references. + if (!m_members) + { + map members; + if (m_actualType->getCategory() == Category::CONTRACT && m_currentContract != nullptr) + { + ContractDefinition const& contract = dynamic_cast(*m_actualType).getContractDefinition(); + vector currentBases = m_currentContract->getLinearizedBaseContracts(); + if (find(currentBases.begin(), currentBases.end(), &contract) != currentBases.end()) + // We are accessing the type of a base contract, so add all public and private + // functions. Note that this does not add inherited functions on purpose. + for (ASTPointer const& f: contract.getDefinedFunctions()) + if (f->getName() != contract.getName()) + members[f->getName()] = make_shared(*f); + } + m_members.reset(new MemberList(members)); + } + return *m_members; +} + + MagicType::MagicType(MagicType::Kind _kind): m_kind(_kind) { diff --git a/libsolidity/Types.h b/libsolidity/Types.h index cb8a8db0d..e6c99fe3b 100644 --- a/libsolidity/Types.h +++ b/libsolidity/Types.h @@ -442,7 +442,8 @@ class TypeType: public Type { public: virtual Category getCategory() const override { return Category::TYPE; } - TypeType(TypePointer const& _actualType): m_actualType(_actualType) {} + TypeType(TypePointer const& _actualType, ContractDefinition const* _currentContract = nullptr): + m_actualType(_actualType), m_currentContract(_currentContract) {} TypePointer const& getActualType() const { return m_actualType; } virtual TypePointer binaryOperatorResult(Token::Value, TypePointer const&) const override { return TypePointer(); } @@ -451,9 +452,14 @@ public: virtual u256 getStorageSize() const override { BOOST_THROW_EXCEPTION(InternalCompilerError() << errinfo_comment("Storage size of non-storable type type requested.")); } virtual bool canLiveOutsideStorage() const override { return false; } virtual std::string toString() const override { return "type(" + m_actualType->toString() + ")"; } + virtual MemberList const& getMembers() const override; private: TypePointer m_actualType; + /// Context in which this type is used (influences visibility etc.), can be nullptr. + ContractDefinition const* m_currentContract; + /// List of member types, will be lazy-initialized because of recursive references. + mutable std::unique_ptr m_members; }; diff --git a/libsolidity/grammar.txt b/libsolidity/grammar.txt index f06d4def2..11d99854c 100644 --- a/libsolidity/grammar.txt +++ b/libsolidity/grammar.txt @@ -1,7 +1,10 @@ -ContractDefinition = 'contract' Identifier '{' ContractPart* '}' +ContractDefinition = 'contract' Identifier + ( 'is' InheritanceSpecifier (',' InheritanceSpecifier )* )? + '{' ContractPart* '}' ContractPart = VariableDeclaration ';' | StructDefinition | FunctionDefinition | 'public:' | 'private:' +InheritanceSpecifier = Identifier ( '(' Expression ( ',' Expression )* ')' )? StructDefinition = 'struct' Identifier '{' ( VariableDeclaration (';' VariableDeclaration)* )? '} diff --git a/test/SolidityCompiler.cpp b/test/SolidityCompiler.cpp index b4874e195..53daa9dfe 100644 --- a/test/SolidityCompiler.cpp +++ b/test/SolidityCompiler.cpp @@ -64,7 +64,7 @@ bytes compileContract(const string& _sourceCode) if (ContractDefinition* contract = dynamic_cast(node.get())) { Compiler compiler; - compiler.compileContract(*contract, {}, map{}); + compiler.compileContract(*contract, map{}); // debug //compiler.streamAssembly(cout); diff --git a/test/SolidityEndToEndTest.cpp b/test/SolidityEndToEndTest.cpp index dd57a1d15..cba926d6d 100644 --- a/test/SolidityEndToEndTest.cpp +++ b/test/SolidityEndToEndTest.cpp @@ -1493,6 +1493,141 @@ BOOST_AUTO_TEST_CASE(value_for_constructor) BOOST_REQUIRE(callContractFunction("getBalances()") == encodeArgs(12, 10)); } +BOOST_AUTO_TEST_CASE(virtual_function_calls) +{ + char const* sourceCode = R"( + contract Base { + function f() returns (uint i) { return g(); } + function g() returns (uint i) { return 1; } + } + contract Derived is Base { + function g() returns (uint i) { return 2; } + } + )"; + compileAndRun(sourceCode, 0, "Derived"); + BOOST_CHECK(callContractFunction("g()") == encodeArgs(2)); + BOOST_CHECK(callContractFunction("f()") == encodeArgs(2)); +} + +BOOST_AUTO_TEST_CASE(access_base_storage) +{ + char const* sourceCode = R"( + contract Base { + uint dataBase; + function getViaBase() returns (uint i) { return dataBase; } + } + contract Derived is Base { + uint dataDerived; + function setData(uint base, uint derived) returns (bool r) { + dataBase = base; + dataDerived = derived; + return true; + } + function getViaDerived() returns (uint base, uint derived) { + base = dataBase; + derived = dataDerived; + } + } + )"; + compileAndRun(sourceCode, 0, "Derived"); + BOOST_CHECK(callContractFunction("setData(uint256,uint256)", 1, 2) == encodeArgs(true)); + BOOST_CHECK(callContractFunction("getViaBase()") == encodeArgs(1)); + BOOST_CHECK(callContractFunction("getViaDerived()") == encodeArgs(1, 2)); +} + +BOOST_AUTO_TEST_CASE(single_copy_with_multiple_inheritance) +{ + char const* sourceCode = R"( + contract Base { + uint data; + function setData(uint i) { data = i; } + function getViaBase() returns (uint i) { return data; } + } + contract A is Base { function setViaA(uint i) { setData(i); } } + contract B is Base { function getViaB() returns (uint i) { return getViaBase(); } } + contract Derived is A, B, Base { } + )"; + compileAndRun(sourceCode, 0, "Derived"); + BOOST_CHECK(callContractFunction("getViaB()") == encodeArgs(0)); + BOOST_CHECK(callContractFunction("setViaA(uint256)", 23) == encodeArgs()); + BOOST_CHECK(callContractFunction("getViaB()") == encodeArgs(23)); +} + +BOOST_AUTO_TEST_CASE(explicit_base_cass) +{ + char const* sourceCode = R"( + contract BaseBase { function g() returns (uint r) { return 1; } } + contract Base is BaseBase { function g() returns (uint r) { return 2; } } + contract Derived is Base { + function f() returns (uint r) { return BaseBase.g(); } + function g() returns (uint r) { return 3; } + } + )"; + compileAndRun(sourceCode, 0, "Derived"); + BOOST_CHECK(callContractFunction("g()") == encodeArgs(3)); + BOOST_CHECK(callContractFunction("f()") == encodeArgs(1)); +} + +BOOST_AUTO_TEST_CASE(base_constructor_arguments) +{ + char const* sourceCode = R"( + contract BaseBase { + uint m_a; + function BaseBase(uint a) { + m_a = a; + } + } + contract Base is BaseBase(7) { + function Base() { + m_a *= m_a; + } + } + contract Derived is Base() { + function getA() returns (uint r) { return m_a; } + } + )"; + compileAndRun(sourceCode, 0, "Derived"); + BOOST_CHECK(callContractFunction("getA()") == encodeArgs(7 * 7)); +} + +BOOST_AUTO_TEST_CASE(function_usage_in_constructor_arguments) +{ + char const* sourceCode = R"( + contract BaseBase { + uint m_a; + function BaseBase(uint a) { + m_a = a; + } + function g() returns (uint r) { return 2; } + } + contract Base is BaseBase(BaseBase.g()) { + } + contract Derived is Base() { + function getA() returns (uint r) { return m_a; } + } + )"; + compileAndRun(sourceCode, 0, "Derived"); + BOOST_CHECK(callContractFunction("getA()") == encodeArgs(2)); +} + +BOOST_AUTO_TEST_CASE(constructor_argument_overriding) +{ + char const* sourceCode = R"( + contract BaseBase { + uint m_a; + function BaseBase(uint a) { + m_a = a; + } + } + contract Base is BaseBase(2) { } + contract Derived is Base, BaseBase(3) { + function getA() returns (uint r) { return m_a; } + } + )"; + compileAndRun(sourceCode, 0, "Derived"); + BOOST_CHECK(callContractFunction("getA()") == encodeArgs(3)); +} + BOOST_AUTO_TEST_SUITE_END() } diff --git a/test/SolidityNameAndTypeResolution.cpp b/test/SolidityNameAndTypeResolution.cpp index e2b4f160d..6c8fd1b1c 100644 --- a/test/SolidityNameAndTypeResolution.cpp +++ b/test/SolidityNameAndTypeResolution.cpp @@ -357,7 +357,6 @@ BOOST_AUTO_TEST_CASE(function_canonical_signature_type_aliases) } } - BOOST_AUTO_TEST_CASE(hash_collision_in_interface) { char const* text = "contract test {\n" @@ -369,6 +368,128 @@ BOOST_AUTO_TEST_CASE(hash_collision_in_interface) BOOST_CHECK_THROW(parseTextAndResolveNames(text), TypeError); } +BOOST_AUTO_TEST_CASE(inheritance_basic) +{ + char const* text = R"( + contract base { uint baseMember; struct BaseType { uint element; } } + contract derived is base { + BaseType data; + function f() { baseMember = 7; } + } + )"; + BOOST_CHECK_NO_THROW(parseTextAndResolveNames(text)); +} + +BOOST_AUTO_TEST_CASE(inheritance_diamond_basic) +{ + char const* text = R"( + contract root { function rootFunction() {} } + contract inter1 is root { function f() {} } + contract inter2 is root { function f() {} } + contract derived is inter1, inter2, root { + function g() { f(); rootFunction(); } + } + )"; + BOOST_CHECK_NO_THROW(parseTextAndResolveNames(text)); +} + +BOOST_AUTO_TEST_CASE(cyclic_inheritance) +{ + char const* text = R"( + contract A is B { } + contract B is A { } + )"; + BOOST_CHECK_THROW(parseTextAndResolveNames(text), TypeError); +} + +BOOST_AUTO_TEST_CASE(illegal_override_direct) +{ + char const* text = R"( + contract B { function f() {} } + contract C is B { function f(uint i) {} } + )"; + BOOST_CHECK_THROW(parseTextAndResolveNames(text), TypeError); +} + +BOOST_AUTO_TEST_CASE(illegal_override_indirect) +{ + char const* text = R"( + contract A { function f(uint a) {} } + contract B { function f() {} } + contract C is A, B { } + )"; + BOOST_CHECK_THROW(parseTextAndResolveNames(text), TypeError); +} + +BOOST_AUTO_TEST_CASE(complex_inheritance) +{ + char const* text = R"( + contract A { function f() { uint8 x = C(0).g(); } } + contract B { function f() {} function g() returns (uint8 r) {} } + contract C is A, B { } + )"; + BOOST_CHECK_NO_THROW(parseTextAndResolveNames(text)); +} + +BOOST_AUTO_TEST_CASE(constructor_visibility) +{ + // The constructor of a base class should not be visible in the derived class + char const* text = R"( + contract A { function A() { } } + contract B is A { function f() { A x = A(0); } } + )"; + BOOST_CHECK_NO_THROW(parseTextAndResolveNames(text)); +} + +BOOST_AUTO_TEST_CASE(overriding_constructor) +{ + // It is fine to "override" constructor of a base class since it is invisible + char const* text = R"( + contract A { function A() { } } + contract B is A { function A() returns (uint8 r) {} } + )"; + BOOST_CHECK_NO_THROW(parseTextAndResolveNames(text)); +} + +BOOST_AUTO_TEST_CASE(missing_base_constructor_arguments) +{ + char const* text = R"( + contract A { function A(uint a) { } } + contract B is A { } + )"; + BOOST_CHECK_THROW(parseTextAndResolveNames(text), TypeError); +} + +BOOST_AUTO_TEST_CASE(base_constructor_arguments_override) +{ + char const* text = R"( + contract A { function A(uint a) { } } + contract B is A { } + )"; + BOOST_CHECK_THROW(parseTextAndResolveNames(text), TypeError); +} + +BOOST_AUTO_TEST_CASE(implicit_derived_to_base_conversion) +{ + char const* text = R"( + contract A { } + contract B is A { + function f() { A a = B(1); } + } + )"; + BOOST_CHECK_NO_THROW(parseTextAndResolveNames(text)); +} +BOOST_AUTO_TEST_CASE(implicit_base_to_derived_conversion) +{ + char const* text = R"( + contract A { } + contract B is A { + function f() { B b = A(1); } + } + )"; + BOOST_CHECK_THROW(parseTextAndResolveNames(text), TypeError); +} + BOOST_AUTO_TEST_SUITE_END() } diff --git a/test/SolidityParser.cpp b/test/SolidityParser.cpp index 86f935c39..91e571306 100644 --- a/test/SolidityParser.cpp +++ b/test/SolidityParser.cpp @@ -495,6 +495,51 @@ BOOST_AUTO_TEST_CASE(multiple_contracts_and_imports) BOOST_CHECK_NO_THROW(parseText(text)); } +BOOST_AUTO_TEST_CASE(contract_inheritance) +{ + char const* text = "contract base {\n" + " function fun() {\n" + " uint64(2);\n" + " }\n" + "}\n" + "contract derived is base {\n" + " function fun() {\n" + " uint64(2);\n" + " }\n" + "}\n"; + BOOST_CHECK_NO_THROW(parseText(text)); +} + +BOOST_AUTO_TEST_CASE(contract_multiple_inheritance) +{ + char const* text = "contract base {\n" + " function fun() {\n" + " uint64(2);\n" + " }\n" + "}\n" + "contract derived is base, nonExisting {\n" + " function fun() {\n" + " uint64(2);\n" + " }\n" + "}\n"; + BOOST_CHECK_NO_THROW(parseText(text)); +} + +BOOST_AUTO_TEST_CASE(contract_multiple_inheritance_with_arguments) +{ + char const* text = "contract base {\n" + " function fun() {\n" + " uint64(2);\n" + " }\n" + "}\n" + "contract derived is base(2), nonExisting(\"abc\", \"def\", base.fun()) {\n" + " function fun() {\n" + " uint64(2);\n" + " }\n" + "}\n"; + BOOST_CHECK_NO_THROW(parseText(text)); +} + BOOST_AUTO_TEST_SUITE_END() }