Browse Source

Merge pull request #528 from chriseth/sol_constructor

Special handling for constructor.
cl-refactor
Gav Wood 10 years ago
parent
commit
9320e76952
  1. 93
      libsolidity/Compiler.cpp
  2. 13
      libsolidity/Compiler.h
  3. 22
      test/solidityEndToEndTest.cpp

93
libsolidity/Compiler.cpp

@ -43,42 +43,59 @@ void Compiler::compileContract(ContractDefinition& _contract)
{ {
m_context = CompilerContext(); // clear it just in case m_context = CompilerContext(); // clear it just in case
//@todo constructor
for (ASTPointer<FunctionDefinition> const& function: _contract.getDefinedFunctions()) for (ASTPointer<FunctionDefinition> const& function: _contract.getDefinedFunctions())
m_context.addFunction(*function); if (function->getName() != _contract.getName()) // don't add the constructor here
//@todo sort them? m_context.addFunction(*function);
for (ASTPointer<VariableDeclaration> const& variable: _contract.getStateVariables()) registerStateVariables(_contract);
m_context.addStateVariable(*variable);
appendFunctionSelector(_contract.getDefinedFunctions()); appendFunctionSelector(_contract);
for (ASTPointer<FunctionDefinition> const& function: _contract.getDefinedFunctions()) for (ASTPointer<FunctionDefinition> const& function: _contract.getDefinedFunctions())
function->accept(*this); if (function->getName() != _contract.getName()) // don't add the constructor here
function->accept(*this);
packIntoContractCreator(); packIntoContractCreator(_contract);
} }
void Compiler::packIntoContractCreator() void Compiler::packIntoContractCreator(ContractDefinition const& _contract)
{ {
CompilerContext creatorContext; CompilerContext runtimeContext;
eth::AssemblyItem sub = creatorContext.addSubroutine(m_context.getAssembly()); swap(m_context, runtimeContext);
registerStateVariables(_contract);
FunctionDefinition* constructor = nullptr;
for (ASTPointer<FunctionDefinition> const& function: _contract.getDefinedFunctions())
if (function->getName() == _contract.getName())
{
constructor = function.get();
break;
}
if (constructor)
{
eth::AssemblyItem returnTag = m_context.pushNewTag();
m_context.addFunction(*constructor); // note that it cannot be called due to syntactic reasons
//@todo copy constructor arguments from calldata to memory prior to this
//@todo calling other functions inside the constructor should either trigger a parse error
//or we should copy them here (register them above and call "accept") - detecting which
// functions are referenced / called needs to be done in a recursive way.
appendCalldataUnpacker(*constructor, true);
m_context.appendJumpTo(m_context.getFunctionEntryLabel(*constructor));
constructor->accept(*this);
m_context << returnTag;
}
eth::AssemblyItem sub = m_context.addSubroutine(runtimeContext.getAssembly());
// stack contains sub size // stack contains sub size
creatorContext << eth::Instruction::DUP1 << sub << u256(0) << eth::Instruction::CODECOPY; m_context << eth::Instruction::DUP1 << sub << u256(0) << eth::Instruction::CODECOPY;
creatorContext << u256(0) << eth::Instruction::RETURN; m_context << u256(0) << eth::Instruction::RETURN;
swap(m_context, creatorContext);
} }
void Compiler::appendFunctionSelector(vector<ASTPointer<FunctionDefinition>> const& _functions) void Compiler::appendFunctionSelector(ContractDefinition const& _contract)
{ {
// sort all public functions and store them together with a tag for their argument decoding section vector<FunctionDefinition const*> interfaceFunctions = _contract.getInterfaceFunctions();
map<string, pair<FunctionDefinition const*, eth::AssemblyItem>> publicFunctions; vector<eth::AssemblyItem> callDataUnpackerEntryPoints;
for (ASTPointer<FunctionDefinition> const& f: _functions)
if (f->isPublic())
publicFunctions.insert(make_pair(f->getName(), make_pair(f.get(), m_context.newTag())));
//@todo remove constructor if (interfaceFunctions.size() > 255)
if (publicFunctions.size() > 255)
BOOST_THROW_EXCEPTION(CompilerError() << errinfo_comment("More than 255 public functions for contract.")); BOOST_THROW_EXCEPTION(CompilerError() << errinfo_comment("More than 255 public functions for contract."));
// retrieve the first byte of the call data, which determines the called function // retrieve the first byte of the call data, which determines the called function
@ -90,21 +107,20 @@ void Compiler::appendFunctionSelector(vector<ASTPointer<FunctionDefinition>> con
<< eth::dupInstruction(2); << eth::dupInstruction(2);
// stack here: 1 0 <funid> 0, stack top will be counted up until it matches funid // stack here: 1 0 <funid> 0, stack top will be counted up until it matches funid
for (pair<string, pair<FunctionDefinition const*, eth::AssemblyItem>> const& f: publicFunctions) for (unsigned funid = 0; funid < interfaceFunctions.size(); ++funid)
{ {
eth::AssemblyItem const& callDataUnpackerEntry = f.second.second; callDataUnpackerEntryPoints.push_back(m_context.newTag());
m_context << eth::dupInstruction(2) << eth::dupInstruction(2) << eth::Instruction::EQ; m_context << eth::dupInstruction(2) << eth::dupInstruction(2) << eth::Instruction::EQ;
m_context.appendConditionalJumpTo(callDataUnpackerEntry); m_context.appendConditionalJumpTo(callDataUnpackerEntryPoints.back());
m_context << eth::dupInstruction(4) << eth::Instruction::ADD; m_context << eth::dupInstruction(4) << eth::Instruction::ADD;
//@todo avoid the last ADD (or remove it in the optimizer) //@todo avoid the last ADD (or remove it in the optimizer)
} }
m_context << eth::Instruction::STOP; // function not found m_context << eth::Instruction::STOP; // function not found
for (pair<string, pair<FunctionDefinition const*, eth::AssemblyItem>> const& f: publicFunctions) for (unsigned funid = 0; funid < interfaceFunctions.size(); ++funid)
{ {
FunctionDefinition const& function = *f.second.first; FunctionDefinition const& function = *interfaceFunctions[funid];
eth::AssemblyItem const& callDataUnpackerEntry = f.second.second; m_context << callDataUnpackerEntryPoints[funid];
m_context << callDataUnpackerEntry;
eth::AssemblyItem returnTag = m_context.pushNewTag(); eth::AssemblyItem returnTag = m_context.pushNewTag();
appendCalldataUnpacker(function); appendCalldataUnpacker(function);
m_context.appendJumpTo(m_context.getFunctionEntryLabel(function)); m_context.appendJumpTo(m_context.getFunctionEntryLabel(function));
@ -113,10 +129,11 @@ void Compiler::appendFunctionSelector(vector<ASTPointer<FunctionDefinition>> con
} }
} }
void Compiler::appendCalldataUnpacker(FunctionDefinition const& _function) unsigned Compiler::appendCalldataUnpacker(FunctionDefinition const& _function, bool _fromMemory)
{ {
// We do not check the calldata size, everything is zero-padded. // We do not check the calldata size, everything is zero-padded.
unsigned dataOffset = 1; unsigned dataOffset = 1;
eth::Instruction load = _fromMemory ? eth::Instruction::MLOAD : eth::Instruction::CALLDATALOAD;
//@todo this can be done more efficiently, saving some CALLDATALOAD calls //@todo this can be done more efficiently, saving some CALLDATALOAD calls
for (ASTPointer<VariableDeclaration> const& var: _function.getParameters()) for (ASTPointer<VariableDeclaration> const& var: _function.getParameters())
@ -127,12 +144,13 @@ void Compiler::appendCalldataUnpacker(FunctionDefinition const& _function)
<< errinfo_sourceLocation(var->getLocation()) << errinfo_sourceLocation(var->getLocation())
<< errinfo_comment("Type " + var->getType()->toString() + " not yet supported.")); << errinfo_comment("Type " + var->getType()->toString() + " not yet supported."));
if (numBytes == 32) if (numBytes == 32)
m_context << u256(dataOffset) << eth::Instruction::CALLDATALOAD; m_context << u256(dataOffset) << load;
else else
m_context << (u256(1) << ((32 - numBytes) * 8)) << u256(dataOffset) m_context << (u256(1) << ((32 - numBytes) * 8)) << u256(dataOffset)
<< eth::Instruction::CALLDATALOAD << eth::Instruction::DIV; << load << eth::Instruction::DIV;
dataOffset += numBytes; dataOffset += numBytes;
} }
return dataOffset;
} }
void Compiler::appendReturnValuePacker(FunctionDefinition const& _function) void Compiler::appendReturnValuePacker(FunctionDefinition const& _function)
@ -158,6 +176,13 @@ void Compiler::appendReturnValuePacker(FunctionDefinition const& _function)
m_context << u256(dataOffset) << u256(0) << eth::Instruction::RETURN; m_context << u256(dataOffset) << u256(0) << eth::Instruction::RETURN;
} }
void Compiler::registerStateVariables(ContractDefinition const& _contract)
{
//@todo sort them?
for (ASTPointer<VariableDeclaration> const& variable: _contract.getStateVariables())
m_context.addStateVariable(*variable);
}
bool Compiler::visit(FunctionDefinition& _function) bool Compiler::visit(FunctionDefinition& _function)
{ {
//@todo to simplify this, the calling convention could by changed such that //@todo to simplify this, the calling convention could by changed such that

13
libsolidity/Compiler.h

@ -40,12 +40,17 @@ public:
static bytes compile(ContractDefinition& _contract, bool _optimize); static bytes compile(ContractDefinition& _contract, bool _optimize);
private: private:
/// Creates a new compiler context / assembly and packs the current code into the data part. /// Creates a new compiler context / assembly, packs the current code into the data part and
void packIntoContractCreator(); /// adds the constructor code.
void appendFunctionSelector(std::vector<ASTPointer<FunctionDefinition> > const& _functions); void packIntoContractCreator(ContractDefinition const& _contract);
void appendCalldataUnpacker(FunctionDefinition const& _function); void appendFunctionSelector(ContractDefinition const& _contract);
/// Creates code that unpacks the arguments for the given function, from memory if
/// @a _fromMemory is true, otherwise from call data. @returns the size of the data in bytes.
unsigned appendCalldataUnpacker(FunctionDefinition const& _function, bool _fromMemory = false);
void appendReturnValuePacker(FunctionDefinition const& _function); void appendReturnValuePacker(FunctionDefinition const& _function);
void registerStateVariables(ContractDefinition const& _contract);
virtual bool visit(FunctionDefinition& _function) override; virtual bool visit(FunctionDefinition& _function) override;
virtual bool visit(IfStatement& _ifStatement) override; virtual bool visit(IfStatement& _ifStatement) override;
virtual bool visit(WhileStatement& _whileStatement) override; virtual bool visit(WhileStatement& _whileStatement) override;

22
test/solidityEndToEndTest.cpp

@ -700,6 +700,28 @@ BOOST_AUTO_TEST_CASE(structs)
BOOST_CHECK(callContractFunction(0) == bytes({0x01})); BOOST_CHECK(callContractFunction(0) == bytes({0x01}));
} }
BOOST_AUTO_TEST_CASE(constructor)
{
char const* sourceCode = "contract test {\n"
" mapping(uint => uint) data;\n"
" function test() {\n"
" data[7] = 8;\n"
" }\n"
" function get(uint key) returns (uint value) {\n"
" return data[key];"
" }\n"
"}\n";
compileAndRun(sourceCode);
map<u256, byte> data;
data[7] = 8;
auto get = [&](u256 const& _x) -> u256
{
return data[_x];
};
testSolidityAgainstCpp(0, get, u256(6));
testSolidityAgainstCpp(0, get, u256(7));
}
BOOST_AUTO_TEST_SUITE_END() BOOST_AUTO_TEST_SUITE_END()
} }

Loading…
Cancel
Save