From df44090ae6dca65fb337457f661ca5ef974e0dac Mon Sep 17 00:00:00 2001 From: Christian Date: Thu, 15 Jan 2015 20:04:24 +0100 Subject: [PATCH] Inheritance in compiler. --- libsolidity/Compiler.cpp | 24 +++++++----- libsolidity/CompilerContext.cpp | 11 +++++- libsolidity/CompilerContext.h | 4 ++ libsolidity/ExpressionCompiler.cpp | 2 +- test/SolidityEndToEndTest.cpp | 60 ++++++++++++++++++++++++++++++ 5 files changed, 90 insertions(+), 11 deletions(-) diff --git a/libsolidity/Compiler.cpp b/libsolidity/Compiler.cpp index d86c939c8..aa3022aad 100644 --- a/libsolidity/Compiler.cpp +++ b/libsolidity/Compiler.cpp @@ -21,6 +21,7 @@ */ #include +#include #include #include #include @@ -40,14 +41,16 @@ void Compiler::compileContract(ContractDefinition const& _contract, m_context = CompilerContext(); // clear it just in case initializeContext(_contract, _contracts); - for (ASTPointer const& function: _contract.getDefinedFunctions()) - if (function->getName() != _contract.getName()) // don't add the constructor here - m_context.addFunction(*function); + for (ContractDefinition const* contract: _contract.getLinearizedBaseContracts()) + for (ASTPointer const& function: contract->getDefinedFunctions()) + if (function->getName() != contract->getName()) // don't add the constructor here + m_context.addFunction(*function); appendFunctionSelector(_contract); - for (ASTPointer const& function: _contract.getDefinedFunctions()) - if (function->getName() != _contract.getName()) // don't add the constructor here - function->accept(*this); + for (ContractDefinition const* contract: _contract.getLinearizedBaseContracts()) + for (ASTPointer const& function: contract->getDefinedFunctions()) + if (function->getName() != contract->getName()) // don't add the constructor here + function->accept(*this); // Swap the runtime context with the creation-time context swap(m_context, m_runtimeContext); @@ -65,13 +68,16 @@ void Compiler::initializeContext(ContractDefinition const& _contract, void Compiler::packIntoContractCreator(ContractDefinition const& _contract, CompilerContext const& _runtimeContext) { set neededFunctions; + // TODO constructors of base classes FunctionDefinition const* constructor = _contract.getConstructor(); if (constructor) neededFunctions = getFunctionsNeededByConstructor(*constructor); + // TODO we should add the overridden functions for (FunctionDefinition const* fun: neededFunctions) m_context.addFunction(*fun); + // we have many of them now if (constructor) appendConstructorCall(*constructor); @@ -191,9 +197,9 @@ void Compiler::appendReturnValuePacker(FunctionDefinition const& _function) void Compiler::registerStateVariables(ContractDefinition const& _contract) { - //@todo sort them? - for (ASTPointer const& variable: _contract.getStateVariables()) - m_context.addStateVariable(*variable); + for (ContractDefinition const* contract: boost::adaptors::reverse(_contract.getLinearizedBaseContracts())) + for (ASTPointer const& variable: contract->getStateVariables()) + m_context.addStateVariable(*variable); } bool Compiler::visit(FunctionDefinition const& _function) diff --git a/libsolidity/CompilerContext.cpp b/libsolidity/CompilerContext.cpp index 29e98eabf..27ec3efd5 100644 --- a/libsolidity/CompilerContext.cpp +++ b/libsolidity/CompilerContext.cpp @@ -61,7 +61,9 @@ void CompilerContext::addAndInitializeVariable(VariableDeclaration const& _decla void CompilerContext::addFunction(FunctionDefinition const& _function) { - m_functionEntryLabels.insert(std::make_pair(&_function, m_asm.newTag())); + eth::AssemblyItem tag(m_asm.newTag()); + m_functionEntryLabels.insert(make_pair(&_function, tag)); + m_virtualFunctionEntryLabels.insert(make_pair(_function.getName(), tag)); } bytes const& CompilerContext::getCompiledContract(const ContractDefinition& _contract) const @@ -83,6 +85,13 @@ eth::AssemblyItem CompilerContext::getFunctionEntryLabel(FunctionDefinition cons return res->second.tag(); } +eth::AssemblyItem CompilerContext::getVirtualFunctionEntryLabel(FunctionDefinition const& _function) const +{ + auto res = m_virtualFunctionEntryLabels.find(_function.getName()); + solAssert(res != m_virtualFunctionEntryLabels.end(), "Function entry label not found."); + return res->second.tag(); +} + unsigned CompilerContext::getBaseStackOffsetOfVariable(Declaration const& _declaration) const { auto res = m_localVariables.find(&_declaration); diff --git a/libsolidity/CompilerContext.h b/libsolidity/CompilerContext.h index cf505d654..cde992d58 100644 --- a/libsolidity/CompilerContext.h +++ b/libsolidity/CompilerContext.h @@ -57,6 +57,8 @@ public: bool isStateVariable(Declaration const* _declaration) const { return m_stateVariables.count(_declaration) != 0; } eth::AssemblyItem getFunctionEntryLabel(FunctionDefinition const& _function) const; + /// @returns the entry label of the given function and takes overrides into account. + eth::AssemblyItem getVirtualFunctionEntryLabel(FunctionDefinition const& _function) const; /// Returns the distance of the given local variable from the top of the local variable stack. unsigned getBaseStackOffsetOfVariable(Declaration const& _declaration) const; /// If supplied by a value returned by @ref getBaseStackOffsetOfVariable(variable), returns @@ -116,6 +118,8 @@ private: unsigned m_localVariablesSize; /// Labels pointing to the entry points of funcitons. std::map m_functionEntryLabels; + /// Labels pointing to the entry points of function overrides. + std::map m_virtualFunctionEntryLabels; }; } diff --git a/libsolidity/ExpressionCompiler.cpp b/libsolidity/ExpressionCompiler.cpp index df90d0d9d..5a45bfd6d 100644 --- a/libsolidity/ExpressionCompiler.cpp +++ b/libsolidity/ExpressionCompiler.cpp @@ -453,7 +453,7 @@ void ExpressionCompiler::endVisit(Identifier const& _identifier) } if (FunctionDefinition const* functionDef = dynamic_cast(declaration)) { - m_context << m_context.getFunctionEntryLabel(*functionDef).pushTag(); + m_context << m_context.getVirtualFunctionEntryLabel(*functionDef).pushTag(); return; } if (dynamic_cast(declaration)) diff --git a/test/SolidityEndToEndTest.cpp b/test/SolidityEndToEndTest.cpp index dd57a1d15..934b39ad6 100644 --- a/test/SolidityEndToEndTest.cpp +++ b/test/SolidityEndToEndTest.cpp @@ -1493,6 +1493,66 @@ BOOST_AUTO_TEST_CASE(value_for_constructor) BOOST_REQUIRE(callContractFunction("getBalances()") == encodeArgs(12, 10)); } +BOOST_AUTO_TEST_CASE(virtual_function_calls) +{ + char const* sourceCode = R"( + contract Base { + function f() returns (uint i) { return g(); } + function g() returns (uint i) { return 1; } + } + contract Derived is Base { + function g() returns (uint i) { return 2; } + } + )"; + compileAndRun(sourceCode, 0, "Derived"); + BOOST_CHECK(callContractFunction("g()") == encodeArgs(2)); + BOOST_CHECK(callContractFunction("f()") == encodeArgs(2)); +} + +BOOST_AUTO_TEST_CASE(access_base_storage) +{ + char const* sourceCode = R"( + contract Base { + uint dataBase; + function getViaBase() returns (uint i) { return dataBase; } + } + contract Derived is Base { + uint dataDerived; + function setData(uint base, uint derived) returns (bool r) { + dataBase = base; + dataDerived = derived; + return true; + } + function getViaDerived() returns (uint base, uint derived) { + base = dataBase; + derived = dataDerived; + } + } + )"; + compileAndRun(sourceCode, 0, "Derived"); + BOOST_CHECK(callContractFunction("setData(uint256,uint256)", 1, 2) == encodeArgs(true)); + BOOST_CHECK(callContractFunction("getViaBase()") == encodeArgs(1)); + BOOST_CHECK(callContractFunction("getViaDerived()") == encodeArgs(1, 2)); +} + +BOOST_AUTO_TEST_CASE(single_copy_with_multiple_inheritance) +{ + char const* sourceCode = R"( + contract Base { + uint data; + function setData(uint i) { data = i; } + function getViaBase() returns (uint i) { return data; } + } + contract A is Base { function setViaA(uint i) { setData(i); } } + contract B is Base { function getViaB() returns (uint i) { return getViaBase(); } } + contract Derived is A, B, Base { } + )"; + compileAndRun(sourceCode, 0, "Derived"); + BOOST_CHECK(callContractFunction("getViaB()") == encodeArgs(0)); + BOOST_CHECK(callContractFunction("setViaA(uint256)", 23) == encodeArgs()); + BOOST_CHECK(callContractFunction("getViaB()") == encodeArgs(23)); +} + BOOST_AUTO_TEST_SUITE_END() }