From bbc3a1b37e72d9d27a70c5d82eda219836017d3d Mon Sep 17 00:00:00 2001 From: Christian Date: Thu, 13 Nov 2014 01:12:57 +0100 Subject: [PATCH] Struct types. --- libsolidity/AST.cpp | 13 +++++- libsolidity/AST.h | 5 ++- libsolidity/ExpressionCompiler.cpp | 17 ++++++-- libsolidity/NameAndTypeResolver.cpp | 29 ++++++++++++-- libsolidity/NameAndTypeResolver.h | 11 ++++-- libsolidity/Types.cpp | 55 +++++++++++++++++++++++++- libsolidity/Types.h | 16 +++++--- test/solidityEndToEndTest.cpp | 37 +++++++++++++++++ test/solidityNameAndTypeResolution.cpp | 38 ++++++++++++++++++ 9 files changed, 201 insertions(+), 20 deletions(-) diff --git a/libsolidity/AST.cpp b/libsolidity/AST.cpp index d5f856dfc..e8bdecf31 100644 --- a/libsolidity/AST.cpp +++ b/libsolidity/AST.cpp @@ -460,8 +460,17 @@ bool FunctionCall::isTypeConversion() const void MemberAccess::checkTypeRequirements() { - BOOST_THROW_EXCEPTION(InternalCompilerError() << errinfo_comment("Member access not yet implemented.")); - // m_type = ; + m_expression->checkTypeRequirements(); + m_expression->requireLValue(); + if (m_expression->getType()->getCategory() != Type::Category::STRUCT) + BOOST_THROW_EXCEPTION(createTypeError("Member access to a non-struct (is " + + m_expression->getType()->toString() + ")")); + StructType const& type = dynamic_cast(*m_expression->getType()); + unsigned memberIndex = type.memberNameToIndex(*m_memberName); + if (memberIndex >= type.getMemberCount()) + BOOST_THROW_EXCEPTION(createTypeError("Member \"" + *m_memberName + "\" not found in " + type.toString())); + m_type = type.getMemberByIndex(memberIndex).getType(); + m_isLvalue = true; } void IndexAccess::checkTypeRequirements() diff --git a/libsolidity/AST.h b/libsolidity/AST.h index 31ca56f76..80c7dd198 100644 --- a/libsolidity/AST.h +++ b/libsolidity/AST.h @@ -146,7 +146,7 @@ private: /** * Parameter list, used as function parameter list and return list. * None of the parameters is allowed to contain mappings (not even recursively - * inside structs), but (@todo) this is not yet enforced. + * inside structs). */ class ParameterList: public ASTNode { @@ -368,7 +368,6 @@ private: /** * Statement in which a break statement is legal. - * @todo actually check this requirement. */ class BreakableStatement: public Statement { @@ -629,6 +628,7 @@ public: ASTPointer const& _memberName): Expression(_location), m_expression(_expression), m_memberName(_memberName) {} virtual void accept(ASTVisitor& _visitor) override; + Expression& getExpression() const { return *m_expression; } ASTString const& getMemberName() const { return *m_memberName; } virtual void checkTypeRequirements() override; @@ -651,6 +651,7 @@ public: Expression& getBaseExpression() const { return *m_base; } Expression& getIndexExpression() const { return *m_index; } + private: ASTPointer m_base; ASTPointer m_index; diff --git a/libsolidity/ExpressionCompiler.cpp b/libsolidity/ExpressionCompiler.cpp index d80b42b35..f37ce39ce 100644 --- a/libsolidity/ExpressionCompiler.cpp +++ b/libsolidity/ExpressionCompiler.cpp @@ -49,7 +49,6 @@ bool ExpressionCompiler::visit(Assignment& _assignment) { _assignment.getRightHandSide().accept(*this); appendTypeConversion(*_assignment.getRightHandSide().getType(), *_assignment.getType()); - m_currentLValue.reset(); _assignment.getLeftHandSide().accept(*this); if (asserts(m_currentLValue.isValid())) BOOST_THROW_EXCEPTION(InternalCompilerError() << errinfo_comment("LValue not retrieved.")); @@ -63,6 +62,7 @@ bool ExpressionCompiler::visit(Assignment& _assignment) appendOrdinaryBinaryOperatorCode(Token::AssignmentToBinaryOp(op), *_assignment.getType()); } m_currentLValue.storeValue(_assignment); + m_currentLValue.reset(); return false; } @@ -90,6 +90,7 @@ void ExpressionCompiler::endVisit(UnaryOperation& _unaryOperation) if (m_currentLValue.storesReferenceOnStack()) m_context << eth::Instruction::SWAP1; m_currentLValue.storeValue(_unaryOperation); + m_currentLValue.reset(); break; case Token::INC: // ++ (pre- or postfix) case Token::DEC: // -- (pre- or postfix) @@ -113,6 +114,7 @@ void ExpressionCompiler::endVisit(UnaryOperation& _unaryOperation) if (m_currentLValue.storesReferenceOnStack()) m_context << eth::Instruction::SWAP1; m_currentLValue.storeValue(_unaryOperation, !_unaryOperation.isPrefixOperation()); + m_currentLValue.reset(); break; case Token::ADD: // + // unary add, so basically no-op @@ -182,10 +184,10 @@ bool ExpressionCompiler::visit(FunctionCall& _functionCall) arguments[i]->accept(*this); appendTypeConversion(*arguments[i]->getType(), *function.getParameters()[i]->getType()); } - m_currentLValue.reset(); _functionCall.getExpression().accept(*this); if (asserts(m_currentLValue.isInCode())) BOOST_THROW_EXCEPTION(InternalCompilerError() << errinfo_comment("Code reference expected.")); + m_currentLValue.reset(); m_context.appendJump(); m_context << returnLabel; @@ -201,9 +203,16 @@ bool ExpressionCompiler::visit(FunctionCall& _functionCall) return false; } -void ExpressionCompiler::endVisit(MemberAccess&) +void ExpressionCompiler::endVisit(MemberAccess& _memberAccess) { - BOOST_THROW_EXCEPTION(InternalCompilerError() << errinfo_comment("Member access not yet implemented.")); + if (asserts(m_currentLValue.isInStorage())) + BOOST_THROW_EXCEPTION(InternalCompilerError() << errinfo_comment("Member access to a non-storage value.")); + StructType const& type = dynamic_cast(*_memberAccess.getExpression().getType()); + unsigned memberIndex = type.memberNameToIndex(_memberAccess.getMemberName()); + if (asserts(memberIndex <= type.getMemberCount())) + BOOST_THROW_EXCEPTION(InternalCompilerError() << errinfo_comment("Member not found in struct during compilation.")); + m_context << type.getStorageOffsetOfMember(memberIndex) << eth::Instruction::ADD; + m_currentLValue.retrieveValueIfLValueNotRequested(_memberAccess); } bool ExpressionCompiler::visit(IndexAccess& _indexAccess) diff --git a/libsolidity/NameAndTypeResolver.cpp b/libsolidity/NameAndTypeResolver.cpp index 4a15fe794..5bc406855 100644 --- a/libsolidity/NameAndTypeResolver.cpp +++ b/libsolidity/NameAndTypeResolver.cpp @@ -37,7 +37,10 @@ void NameAndTypeResolver::resolveNamesAndTypes(ContractDefinition& _contract) reset(); DeclarationRegistrationHelper registrar(m_scopes, _contract); m_currentScope = &m_scopes[&_contract]; - //@todo structs + for (ASTPointer const& structDef: _contract.getDefinedStructs()) + ReferencesResolver resolver(*structDef, *this, nullptr); + for (ASTPointer const& structDef: _contract.getDefinedStructs()) + checkForRecursion(*structDef); for (ASTPointer const& variable: _contract.getStateVariables()) ReferencesResolver resolver(*variable, *this, nullptr); for (ASTPointer const& function: _contract.getDefinedFunctions()) @@ -70,6 +73,24 @@ Declaration* NameAndTypeResolver::getNameFromCurrentScope(ASTString const& _name return m_currentScope->resolveName(_name, _recursive); } +void NameAndTypeResolver::checkForRecursion(StructDefinition const& _struct) +{ + set definitionsSeen; + vector queue = {&_struct}; + while (!queue.empty()) + { + StructDefinition const* def = queue.back(); + queue.pop_back(); + if (definitionsSeen.count(def)) + BOOST_THROW_EXCEPTION(ParserError() << errinfo_sourceLocation(def->getLocation()) + << errinfo_comment("Recursive struct definition.")); + definitionsSeen.insert(def); + for (ASTPointer const& member: def->getMembers()) + if (member->getType()->getCategory() == Type::Category::STRUCT) + queue.push_back(dynamic_cast(*member->getTypeName()).getReferencedStruct()); + } +} + void NameAndTypeResolver::reset() { m_scopes.clear(); @@ -163,8 +184,8 @@ void DeclarationRegistrationHelper::registerDeclaration(Declaration& _declaratio } ReferencesResolver::ReferencesResolver(ASTNode& _root, NameAndTypeResolver& _resolver, - ParameterList* _returnParameters): - m_resolver(_resolver), m_returnParameters(_returnParameters) + ParameterList* _returnParameters, bool _allowLazyTypes): + m_resolver(_resolver), m_returnParameters(_returnParameters), m_allowLazyTypes(_allowLazyTypes) { _root.accept(*this); } @@ -175,6 +196,8 @@ void ReferencesResolver::endVisit(VariableDeclaration& _variable) // or mapping if (_variable.getTypeName()) _variable.setType(_variable.getTypeName()->toType()); + else if (!m_allowLazyTypes) + BOOST_THROW_EXCEPTION(_variable.createTypeError("Explicit type needed.")); // otherwise we have a "var"-declaration whose type is resolved by the first assignment } diff --git a/libsolidity/NameAndTypeResolver.h b/libsolidity/NameAndTypeResolver.h index 909024942..d335807e5 100644 --- a/libsolidity/NameAndTypeResolver.h +++ b/libsolidity/NameAndTypeResolver.h @@ -55,10 +55,13 @@ public: Declaration* getNameFromCurrentScope(ASTString const& _name, bool _recursive = true); private: + /// Throws if @a _struct contains a recursive loop. Note that recursion via mappings is fine. + void checkForRecursion(StructDefinition const& _struct); void reset(); - /// Maps nodes declaring a scope to scopes, i.e. ContractDefinition, FunctionDeclaration and - /// StructDefinition (@todo not yet implemented), where nullptr denotes the global scope. + /// Maps nodes declaring a scope to scopes, i.e. ContractDefinition and FunctionDeclaration, + /// where nullptr denotes the global scope. Note that structs are not scope since they do + /// not contain code. std::map m_scopes; Scope* m_currentScope; @@ -99,7 +102,8 @@ private: class ReferencesResolver: private ASTVisitor { public: - ReferencesResolver(ASTNode& _root, NameAndTypeResolver& _resolver, ParameterList* _returnParameters); + ReferencesResolver(ASTNode& _root, NameAndTypeResolver& _resolver, + ParameterList* _returnParameters, bool _allowLazyTypes = true); private: virtual void endVisit(VariableDeclaration& _variable) override; @@ -110,6 +114,7 @@ private: NameAndTypeResolver& m_resolver; ParameterList* m_returnParameters; + bool m_allowLazyTypes; }; } diff --git a/libsolidity/Types.cpp b/libsolidity/Types.cpp index e37ed3e5b..63bad5d61 100644 --- a/libsolidity/Types.cpp +++ b/libsolidity/Types.cpp @@ -80,7 +80,7 @@ shared_ptr Type::forLiteral(Literal const& _literal) case Token::NUMBER: return IntegerType::smallestTypeForLiteral(_literal.getValue()); case Token::STRING_LITERAL: - return shared_ptr(); // @todo + return shared_ptr(); // @todo add string literals default: return shared_ptr(); } @@ -231,6 +231,48 @@ u256 StructType::getStorageSize() const return max(1, size); } +bool StructType::canLiveOutsideStorage() const +{ + for (unsigned i = 0; i < getMemberCount(); ++i) + if (!getMemberByIndex(i).getType()->canLiveOutsideStorage()) + return false; + return true; +} + +string StructType::toString() const +{ + return string("struct ") + m_struct.getName(); +} + +unsigned StructType::getMemberCount() const +{ + return m_struct.getMembers().size(); +} + +unsigned StructType::memberNameToIndex(string const& _name) const +{ + vector> const& members = m_struct.getMembers(); + for (unsigned index = 0; index < members.size(); ++index) + if (members[index]->getName() == _name) + return index; + return unsigned(-1); +} + +VariableDeclaration const& StructType::getMemberByIndex(unsigned _index) const +{ + return *m_struct.getMembers()[_index]; +} + +u256 StructType::getStorageOffsetOfMember(unsigned _index) const +{ + //@todo cache member offset? + u256 offset; + vector> const& members = m_struct.getMembers(); + for (unsigned index = 0; index < _index; ++index) + offset += getMemberByIndex(index).getType()->getStorageSize(); + return offset; +} + bool FunctionType::operator==(Type const& _other) const { if (_other.getCategory() != getCategory()) @@ -239,6 +281,12 @@ bool FunctionType::operator==(Type const& _other) const return other.m_function == m_function; } +string FunctionType::toString() const +{ + //@todo nice string for function types + return "function(...)returns(...)"; +} + bool MappingType::operator==(Type const& _other) const { if (_other.getCategory() != getCategory()) @@ -247,6 +295,11 @@ bool MappingType::operator==(Type const& _other) const return *other.m_keyType == *m_keyType && *other.m_valueType == *m_valueType; } +string MappingType::toString() const +{ + return "mapping(" + getKeyType()->toString() + " => " + getValueType()->toString() + ")"; +} + bool TypeType::operator==(Type const& _other) const { if (_other.getCategory() != getCategory()) diff --git a/libsolidity/Types.h b/libsolidity/Types.h index b9bb74dbb..726470172 100644 --- a/libsolidity/Types.h +++ b/libsolidity/Types.h @@ -184,9 +184,14 @@ public: virtual bool operator==(Type const& _other) const override; virtual u256 getStorageSize() const; - //@todo it can, if its members can - virtual bool canLiveOutsideStorage() const { return false; } - virtual std::string toString() const override { return "struct{...}"; } + virtual bool canLiveOutsideStorage() const; + virtual std::string toString() const override; + + unsigned getMemberCount() const; + /// Returns the index of the member with name @a _name or unsigned(-1) if it does not exist. + unsigned memberNameToIndex(std::string const& _name) const; + VariableDeclaration const& getMemberByIndex(unsigned _index) const; + u256 getStorageOffsetOfMember(unsigned _index) const; private: StructDefinition const& m_struct; @@ -204,7 +209,7 @@ public: FunctionDefinition const& getFunction() const { return m_function; } virtual bool operator==(Type const& _other) const override; - virtual std::string toString() const override { return "function(...)returns(...)"; } + virtual std::string toString() const override; virtual u256 getStorageSize() const { BOOST_THROW_EXCEPTION(InternalCompilerError() << errinfo_comment("Storage size of non-storable function type requested.")); } virtual bool canLiveOutsideStorage() const { return false; } @@ -223,11 +228,12 @@ public: m_keyType(_keyType), m_valueType(_valueType) {} virtual bool operator==(Type const& _other) const override; - virtual std::string toString() const override { return "mapping(...=>...)"; } + virtual std::string toString() const override; virtual bool canLiveOutsideStorage() const { return false; } std::shared_ptr getKeyType() const { return m_keyType; } std::shared_ptr getValueType() const { return m_valueType; } + private: std::shared_ptr m_keyType; std::shared_ptr m_valueType; diff --git a/test/solidityEndToEndTest.cpp b/test/solidityEndToEndTest.cpp index 5e6631df6..617cbabc9 100644 --- a/test/solidityEndToEndTest.cpp +++ b/test/solidityEndToEndTest.cpp @@ -663,6 +663,43 @@ BOOST_AUTO_TEST_CASE(multi_level_mapping) testSolidityAgainstCpp(0, f, u256(5), u256(4), u256(0)); } +BOOST_AUTO_TEST_CASE(structs) +{ + char const* sourceCode = "contract test {\n" + " struct s1 {\n" + " uint8 x;\n" + " bool y;\n" + " }\n" + " struct s2 {\n" + " uint32 z;\n" + " s1 s1data;\n" + " mapping(uint8 => s2) recursive;\n" + " }\n" + " s2 data;\n" + " function check() returns (bool ok) {\n" + " return data.z == 1 && data.s1data.x == 2 && \n" + " data.s1data.y == true && \n" + " data.recursive[3].recursive[4].z == 5 && \n" + " data.recursive[4].recursive[3].z == 6 && \n" + " data.recursive[0].s1data.y == false && \n" + " data.recursive[4].z == 9;\n" + " }\n" + " function set() {\n" + " data.z = 1;\n" + " data.s1data.x = 2;\n" + " data.s1data.y = true;\n" + " data.recursive[3].recursive[4].z = 5;\n" + " data.recursive[4].recursive[3].z = 6;\n" + " data.recursive[0].s1data.y = false;\n" + " data.recursive[4].z = 9;\n" + " }\n" + "}\n"; + compileAndRun(sourceCode); + BOOST_CHECK(callContractFunction(0) == bytes({0x00})); + BOOST_CHECK(callContractFunction(1) == bytes()); + BOOST_CHECK(callContractFunction(0) == bytes({0x01})); +} + BOOST_AUTO_TEST_SUITE_END() } diff --git a/test/solidityNameAndTypeResolution.cpp b/test/solidityNameAndTypeResolution.cpp index f46ad6733..930bba0e3 100644 --- a/test/solidityNameAndTypeResolution.cpp +++ b/test/solidityNameAndTypeResolution.cpp @@ -121,6 +121,44 @@ BOOST_AUTO_TEST_CASE(reference_to_later_declaration) BOOST_CHECK_NO_THROW(parseTextAndResolveNames(text)); } +BOOST_AUTO_TEST_CASE(struct_definition_directly_recursive) +{ + char const* text = "contract test {\n" + " struct MyStructName {\n" + " address addr;\n" + " MyStructName x;\n" + " }\n" + "}\n"; + BOOST_CHECK_THROW(parseTextAndResolveNames(text), ParserError); +} + +BOOST_AUTO_TEST_CASE(struct_definition_indirectly_recursive) +{ + char const* text = "contract test {\n" + " struct MyStructName1 {\n" + " address addr;\n" + " uint256 count;\n" + " MyStructName2 x;\n" + " }\n" + " struct MyStructName2 {\n" + " MyStructName1 x;\n" + " }\n" + "}\n"; + BOOST_CHECK_THROW(parseTextAndResolveNames(text), ParserError); +} + +BOOST_AUTO_TEST_CASE(struct_definition_recursion_via_mapping) +{ + char const* text = "contract test {\n" + " struct MyStructName1 {\n" + " address addr;\n" + " uint256 count;\n" + " mapping(uint => MyStructName1) x;\n" + " }\n" + "}\n"; + BOOST_CHECK_NO_THROW(parseTextAndResolveNames(text)); +} + BOOST_AUTO_TEST_CASE(type_inference_smoke_test) { char const* text = "contract test {\n"