Browse Source

Inheritance in compiler.

cl-refactor
Christian 10 years ago
parent
commit
df44090ae6
  1. 24
      libsolidity/Compiler.cpp
  2. 11
      libsolidity/CompilerContext.cpp
  3. 4
      libsolidity/CompilerContext.h
  4. 2
      libsolidity/ExpressionCompiler.cpp
  5. 60
      test/SolidityEndToEndTest.cpp

24
libsolidity/Compiler.cpp

@ -21,6 +21,7 @@
*/ */
#include <algorithm> #include <algorithm>
#include <boost/range/adaptor/reversed.hpp>
#include <libevmcore/Instruction.h> #include <libevmcore/Instruction.h>
#include <libevmcore/Assembly.h> #include <libevmcore/Assembly.h>
#include <libsolidity/AST.h> #include <libsolidity/AST.h>
@ -40,14 +41,16 @@ void Compiler::compileContract(ContractDefinition const& _contract,
m_context = CompilerContext(); // clear it just in case m_context = CompilerContext(); // clear it just in case
initializeContext(_contract, _contracts); initializeContext(_contract, _contracts);
for (ASTPointer<FunctionDefinition> const& function: _contract.getDefinedFunctions()) for (ContractDefinition const* contract: _contract.getLinearizedBaseContracts())
if (function->getName() != _contract.getName()) // don't add the constructor here for (ASTPointer<FunctionDefinition> const& function: contract->getDefinedFunctions())
m_context.addFunction(*function); if (function->getName() != contract->getName()) // don't add the constructor here
m_context.addFunction(*function);
appendFunctionSelector(_contract); appendFunctionSelector(_contract);
for (ASTPointer<FunctionDefinition> const& function: _contract.getDefinedFunctions()) for (ContractDefinition const* contract: _contract.getLinearizedBaseContracts())
if (function->getName() != _contract.getName()) // don't add the constructor here for (ASTPointer<FunctionDefinition> const& function: contract->getDefinedFunctions())
function->accept(*this); if (function->getName() != contract->getName()) // don't add the constructor here
function->accept(*this);
// Swap the runtime context with the creation-time context // Swap the runtime context with the creation-time context
swap(m_context, m_runtimeContext); swap(m_context, m_runtimeContext);
@ -65,13 +68,16 @@ void Compiler::initializeContext(ContractDefinition const& _contract,
void Compiler::packIntoContractCreator(ContractDefinition const& _contract, CompilerContext const& _runtimeContext) void Compiler::packIntoContractCreator(ContractDefinition const& _contract, CompilerContext const& _runtimeContext)
{ {
set<FunctionDefinition const*> neededFunctions; set<FunctionDefinition const*> neededFunctions;
// TODO constructors of base classes
FunctionDefinition const* constructor = _contract.getConstructor(); FunctionDefinition const* constructor = _contract.getConstructor();
if (constructor) if (constructor)
neededFunctions = getFunctionsNeededByConstructor(*constructor); neededFunctions = getFunctionsNeededByConstructor(*constructor);
// TODO we should add the overridden functions
for (FunctionDefinition const* fun: neededFunctions) for (FunctionDefinition const* fun: neededFunctions)
m_context.addFunction(*fun); m_context.addFunction(*fun);
// we have many of them now
if (constructor) if (constructor)
appendConstructorCall(*constructor); appendConstructorCall(*constructor);
@ -191,9 +197,9 @@ void Compiler::appendReturnValuePacker(FunctionDefinition const& _function)
void Compiler::registerStateVariables(ContractDefinition const& _contract) void Compiler::registerStateVariables(ContractDefinition const& _contract)
{ {
//@todo sort them? for (ContractDefinition const* contract: boost::adaptors::reverse(_contract.getLinearizedBaseContracts()))
for (ASTPointer<VariableDeclaration> const& variable: _contract.getStateVariables()) for (ASTPointer<VariableDeclaration> const& variable: contract->getStateVariables())
m_context.addStateVariable(*variable); m_context.addStateVariable(*variable);
} }
bool Compiler::visit(FunctionDefinition const& _function) bool Compiler::visit(FunctionDefinition const& _function)

11
libsolidity/CompilerContext.cpp

@ -61,7 +61,9 @@ void CompilerContext::addAndInitializeVariable(VariableDeclaration const& _decla
void CompilerContext::addFunction(FunctionDefinition const& _function) void CompilerContext::addFunction(FunctionDefinition const& _function)
{ {
m_functionEntryLabels.insert(std::make_pair(&_function, m_asm.newTag())); eth::AssemblyItem tag(m_asm.newTag());
m_functionEntryLabels.insert(make_pair(&_function, tag));
m_virtualFunctionEntryLabels.insert(make_pair(_function.getName(), tag));
} }
bytes const& CompilerContext::getCompiledContract(const ContractDefinition& _contract) const bytes const& CompilerContext::getCompiledContract(const ContractDefinition& _contract) const
@ -83,6 +85,13 @@ eth::AssemblyItem CompilerContext::getFunctionEntryLabel(FunctionDefinition cons
return res->second.tag(); return res->second.tag();
} }
eth::AssemblyItem CompilerContext::getVirtualFunctionEntryLabel(FunctionDefinition const& _function) const
{
auto res = m_virtualFunctionEntryLabels.find(_function.getName());
solAssert(res != m_virtualFunctionEntryLabels.end(), "Function entry label not found.");
return res->second.tag();
}
unsigned CompilerContext::getBaseStackOffsetOfVariable(Declaration const& _declaration) const unsigned CompilerContext::getBaseStackOffsetOfVariable(Declaration const& _declaration) const
{ {
auto res = m_localVariables.find(&_declaration); auto res = m_localVariables.find(&_declaration);

4
libsolidity/CompilerContext.h

@ -57,6 +57,8 @@ public:
bool isStateVariable(Declaration const* _declaration) const { return m_stateVariables.count(_declaration) != 0; } bool isStateVariable(Declaration const* _declaration) const { return m_stateVariables.count(_declaration) != 0; }
eth::AssemblyItem getFunctionEntryLabel(FunctionDefinition const& _function) const; eth::AssemblyItem getFunctionEntryLabel(FunctionDefinition const& _function) const;
/// @returns the entry label of the given function and takes overrides into account.
eth::AssemblyItem getVirtualFunctionEntryLabel(FunctionDefinition const& _function) const;
/// Returns the distance of the given local variable from the top of the local variable stack. /// Returns the distance of the given local variable from the top of the local variable stack.
unsigned getBaseStackOffsetOfVariable(Declaration const& _declaration) const; unsigned getBaseStackOffsetOfVariable(Declaration const& _declaration) const;
/// If supplied by a value returned by @ref getBaseStackOffsetOfVariable(variable), returns /// If supplied by a value returned by @ref getBaseStackOffsetOfVariable(variable), returns
@ -116,6 +118,8 @@ private:
unsigned m_localVariablesSize; unsigned m_localVariablesSize;
/// Labels pointing to the entry points of funcitons. /// Labels pointing to the entry points of funcitons.
std::map<Declaration const*, eth::AssemblyItem> m_functionEntryLabels; std::map<Declaration const*, eth::AssemblyItem> m_functionEntryLabels;
/// Labels pointing to the entry points of function overrides.
std::map<std::string, eth::AssemblyItem> m_virtualFunctionEntryLabels;
}; };
} }

2
libsolidity/ExpressionCompiler.cpp

@ -453,7 +453,7 @@ void ExpressionCompiler::endVisit(Identifier const& _identifier)
} }
if (FunctionDefinition const* functionDef = dynamic_cast<FunctionDefinition const*>(declaration)) if (FunctionDefinition const* functionDef = dynamic_cast<FunctionDefinition const*>(declaration))
{ {
m_context << m_context.getFunctionEntryLabel(*functionDef).pushTag(); m_context << m_context.getVirtualFunctionEntryLabel(*functionDef).pushTag();
return; return;
} }
if (dynamic_cast<VariableDeclaration const*>(declaration)) if (dynamic_cast<VariableDeclaration const*>(declaration))

60
test/SolidityEndToEndTest.cpp

@ -1493,6 +1493,66 @@ BOOST_AUTO_TEST_CASE(value_for_constructor)
BOOST_REQUIRE(callContractFunction("getBalances()") == encodeArgs(12, 10)); BOOST_REQUIRE(callContractFunction("getBalances()") == encodeArgs(12, 10));
} }
BOOST_AUTO_TEST_CASE(virtual_function_calls)
{
char const* sourceCode = R"(
contract Base {
function f() returns (uint i) { return g(); }
function g() returns (uint i) { return 1; }
}
contract Derived is Base {
function g() returns (uint i) { return 2; }
}
)";
compileAndRun(sourceCode, 0, "Derived");
BOOST_CHECK(callContractFunction("g()") == encodeArgs(2));
BOOST_CHECK(callContractFunction("f()") == encodeArgs(2));
}
BOOST_AUTO_TEST_CASE(access_base_storage)
{
char const* sourceCode = R"(
contract Base {
uint dataBase;
function getViaBase() returns (uint i) { return dataBase; }
}
contract Derived is Base {
uint dataDerived;
function setData(uint base, uint derived) returns (bool r) {
dataBase = base;
dataDerived = derived;
return true;
}
function getViaDerived() returns (uint base, uint derived) {
base = dataBase;
derived = dataDerived;
}
}
)";
compileAndRun(sourceCode, 0, "Derived");
BOOST_CHECK(callContractFunction("setData(uint256,uint256)", 1, 2) == encodeArgs(true));
BOOST_CHECK(callContractFunction("getViaBase()") == encodeArgs(1));
BOOST_CHECK(callContractFunction("getViaDerived()") == encodeArgs(1, 2));
}
BOOST_AUTO_TEST_CASE(single_copy_with_multiple_inheritance)
{
char const* sourceCode = R"(
contract Base {
uint data;
function setData(uint i) { data = i; }
function getViaBase() returns (uint i) { return data; }
}
contract A is Base { function setViaA(uint i) { setData(i); } }
contract B is Base { function getViaB() returns (uint i) { return getViaBase(); } }
contract Derived is A, B, Base { }
)";
compileAndRun(sourceCode, 0, "Derived");
BOOST_CHECK(callContractFunction("getViaB()") == encodeArgs(0));
BOOST_CHECK(callContractFunction("setViaA(uint256)", 23) == encodeArgs());
BOOST_CHECK(callContractFunction("getViaB()") == encodeArgs(23));
}
BOOST_AUTO_TEST_SUITE_END() BOOST_AUTO_TEST_SUITE_END()
} }

Loading…
Cancel
Save