/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "utils/flatbuffers.h"

#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "utils/strings/numbers.h"
#include "utils/variant.h"

namespace libtextclassifier3 {
namespace {

// Allocates the repeated-field container that matches the element type of
// the vector type `type` and stores it in `*repeated_field`.
//
// Returns true on success. For element types without a matching container an
// error is logged, `*repeated_field` is left untouched and false is returned.
bool CreateRepeatedField(
    const reflection::Schema* schema, const reflection::Type* type,
    std::unique_ptr<ReflectiveFlatbuffer::RepeatedField>* repeated_field) {
  switch (type->element()) {
    case reflection::Bool:
      repeated_field->reset(new ReflectiveFlatbuffer::TypedRepeatedField<bool>);
      return true;
    case reflection::Int:
      repeated_field->reset(new ReflectiveFlatbuffer::TypedRepeatedField<int>);
      return true;
    case reflection::Long:
      repeated_field->reset(
          new ReflectiveFlatbuffer::TypedRepeatedField<int64>);
      return true;
    case reflection::Float:
      repeated_field->reset(
          new ReflectiveFlatbuffer::TypedRepeatedField<float>);
      return true;
    case reflection::Double:
      repeated_field->reset(
          new ReflectiveFlatbuffer::TypedRepeatedField<double>);
      return true;
    case reflection::String:
      repeated_field->reset(
          new ReflectiveFlatbuffer::TypedRepeatedField<std::string>);
      return true;
    case reflection::Obj:
      // Repeated sub-tables need the schema and type to build their elements.
      repeated_field->reset(
          new ReflectiveFlatbuffer::TypedRepeatedField<ReflectiveFlatbuffer>(
              schema, type));
      return true;
    default:
      TC3_LOG(ERROR) << "Unsupported type: " << type->element();
      return false;
  }
}

}  // namespace

// Specialization for `Model`: forwards to `ModelIdentifier()`.
template <>
const char* FlatbufferFileIdentifier<Model>() {
  return ModelIdentifier();
}

// Creates a new, empty flatbuffer for the root table of the schema.
// Logs an error and returns nullptr if the schema declares no root table.
std::unique_ptr<ReflectiveFlatbuffer> ReflectiveFlatbufferBuilder::NewRoot()
    const {
  if (!schema_->root_table()) {
    TC3_LOG(ERROR) << "No root table specified.";
    return nullptr;
  }
  return std::unique_ptr<ReflectiveFlatbuffer>(
      new ReflectiveFlatbuffer(schema_, schema_->root_table()));
}

// Creates a new, empty flatbuffer for the schema object named `table_name`.
// Returns nullptr if the schema contains no object with that name.
std::unique_ptr<ReflectiveFlatbuffer> ReflectiveFlatbufferBuilder::NewTable(
    StringPiece table_name) const {
  for (const reflection::Object* object : *schema_->objects()) {
    if (table_name.Equals(object->name()->str())) {
      return std::unique_ptr<ReflectiveFlatbuffer>(
          new ReflectiveFlatbuffer(schema_, object));
    }
  }
  return nullptr;
}

// Looks up the field called `field_name` in this table's type.
// Returns nullptr if the type has no fields or no field has that name.
const reflection::Field* ReflectiveFlatbuffer::GetFieldOrNull(
    const StringPiece field_name) const {
  // A type without fields cannot contain the requested one; this mirrors the
  // null check done in GetFieldByOffsetOrNull.
  if (type_->fields() == nullptr) {
    return nullptr;
  }
  // The key is handed to the lookup as a plain `const char*`, so it has to be
  // NUL-terminated. A StringPiece may be a view into a larger buffer, hence
  // the lookup goes through an owned copy.
  const std::string key = field_name.ToString();
  return type_->fields()->LookupByKey(key.c_str());
}

// Resolves a field specification to a field of this table's type, using the
// field name if the specification carries one and the field offset otherwise.
// Returns nullptr if no matching field exists.
const reflection::Field* ReflectiveFlatbuffer::GetFieldOrNull(
    const FlatbufferField* field) const {
  // Lookup by name might be faster as the fields are sorted by name in the
  // schema data, so try that first.
  if (field->field_name() != nullptr) {
    return GetFieldOrNull(field->field_name()->str());
  }
  return GetFieldByOffsetOrNull(field->field_offset());
}

// Walks `field_path` starting at this table. On success, `*field` is the
// field named by the last path element and `*parent` is the table that
// contains it; intermediate sub-tables are obtained through Mutable().
//
// Returns false if the path is missing or empty, if a path element does not
// resolve to a field, or if an intermediate element is not a sub-table.
bool ReflectiveFlatbuffer::GetFieldWithParent(
    const FlatbufferFieldPath* field_path, ReflectiveFlatbuffer** parent,
    reflection::Field const** field) {
  const auto* path = field_path->field();
  if (path == nullptr || path->size() == 0) {
    return false;
  }

  // Unsigned index to match the unsigned size of the path vector.
  for (flatbuffers::uoffset_t i = 0; i < path->size(); i++) {
    // The first element is resolved against this table, every further one
    // against the sub-table reached through the previously resolved field.
    *parent = (i == 0 ? this : (*parent)->Mutable(*field));
    if (*parent == nullptr) {
      return false;
    }
    *field = (*parent)->GetFieldOrNull(path->Get(i));
    if (*field == nullptr) {
      return false;
    }
  }

  return true;
}

// Linearly scans this table's fields for the one with offset `field_offset`.
// Returns nullptr if the type has no fields or none has that offset.
const reflection::Field* ReflectiveFlatbuffer::GetFieldByOffsetOrNull(
    const int field_offset) const {
  if (type_->fields() == nullptr) {
    return nullptr;
  }
  for (const reflection::Field* field : *type_->fields()) {
    if (field->offset() == field_offset) {
      return field;
    }
  }
  return nullptr;
}

// Returns whether `value` holds the kind of value that `field` stores.
// Only bool, int, long, float, double and string fields can match; every
// other field type yields false.
bool ReflectiveFlatbuffer::IsMatchingType(const reflection::Field* field,
                                          const Variant& value) const {
  switch (field->type()->base_type()) {
    case reflection::Bool:
      return value.HasBool();
    case reflection::Int:
      return value.HasInt();
    case reflection::Long:
      return value.HasInt64();
    case reflection::Float:
      return value.HasFloat();
    case reflection::Double:
      return value.HasDouble();
    case reflection::String:
      return value.HasString();
    default:
      return false;
  }
}

// Parses `value` according to the type of `field` and sets the field.
// String fields take the text verbatim; int, long, float and double fields
// are parsed first. Returns false (after logging) if parsing fails or the
// field type is not handled, otherwise the result of Set().
bool ReflectiveFlatbuffer::ParseAndSet(const reflection::Field* field,
                                       const std::string& value) {
  switch (field->type()->base_type()) {
    case reflection::String:
      return Set(field, value);
    case reflection::Int: {
      int32 int_value;
      if (!ParseInt32(value.data(), &int_value)) {
        TC3_LOG(ERROR) << "Could not parse '" << value << "' as int32.";
        return false;
      }
      return Set(field, int_value);
    }
    case reflection::Long: {
      int64 int_value;
      if (!ParseInt64(value.data(), &int_value)) {
        TC3_LOG(ERROR) << "Could not parse '" << value << "' as int64.";
        return false;
      }
      return Set(field, int_value);
    }
    case reflection::Float: {
      // Parsed as double and narrowed afterwards.
      double double_value;
      if (!ParseDouble(value.data(), &double_value)) {
        TC3_LOG(ERROR) << "Could not parse '" << value << "' as float.";
        return false;
      }
      return Set(field, static_cast<float>(double_value));
    }
    case reflection::Double: {
      double double_value;
      if (!ParseDouble(value.data(), &double_value)) {
        TC3_LOG(ERROR) << "Could not parse '" << value << "' as double.";
        return false;
      }
      return Set(field, double_value);
    }
    default:
      TC3_LOG(ERROR) << "Unhandled field type: " << field->type()->base_type();
      return false;
  }
}

// Resolves `path` relative to this table and parses `value` into the field
// it designates. Returns false if the path cannot be resolved or the value
// cannot be parsed and set.
bool ReflectiveFlatbuffer::ParseAndSet(const FlatbufferFieldPath* path,
                                       const std::string& value) {
  ReflectiveFlatbuffer* parent = nullptr;
  const reflection::Field* field = nullptr;
  if (!GetFieldWithParent(path, &parent, &field)) {
    return false;
  }
  return parent->ParseAndSet(field, value);
}

// Returns the mutable sub-table stored in the field called `field_name`.
// Logs an error and returns nullptr if there is no such field.
ReflectiveFlatbuffer* ReflectiveFlatbuffer::Mutable(
    const StringPiece field_name) {
  if (const reflection::Field* field = GetFieldOrNull(field_name)) {
    return Mutable(field);
  }
  TC3_LOG(ERROR) << "Unknown field: " << field_name.ToString();
  return nullptr;
}

// Returns the mutable sub-table stored in `field`, creating an empty one on
// first access. The returned pointer is owned by this object.
// Logs an error and returns nullptr if `field` is not of object type.
ReflectiveFlatbuffer* ReflectiveFlatbuffer::Mutable(
    const reflection::Field* field) {
  if (field->type()->base_type() != reflection::Obj) {
    TC3_LOG(ERROR) << "Field is not of type Object.";
    return nullptr;
  }
  const auto entry = children_.find(field);
  if (entry != children_.end()) {
    return entry->second.get();
  }
  // Not present yet: create an empty table of the field's object type.
  const auto it = children_.insert(
      /*hint=*/entry,
      std::make_pair(
          field,
          std::unique_ptr<ReflectiveFlatbuffer>(new ReflectiveFlatbuffer(
              schema_, schema_->objects()->Get(field->type()->index())))));
  return it->second.get();
}

// Returns the repeated field called `field_name`.
// Logs an error and returns nullptr if there is no such field.
ReflectiveFlatbuffer::RepeatedField* ReflectiveFlatbuffer::Repeated(
    StringPiece field_name) {
  if (const reflection::Field* field = GetFieldOrNull(field_name)) {
    return Repeated(field);
  }
  TC3_LOG(ERROR) << "Unknown field: " << field_name.ToString();
  return nullptr;
}

// Returns the repeated-field container for `field`, creating it on first
// access. The returned pointer is owned by this object.
// Logs an error and returns nullptr if `field` is not of vector type or its
// element type is not supported.
ReflectiveFlatbuffer::RepeatedField* ReflectiveFlatbuffer::Repeated(
    const reflection::Field* field) {
  if (field->type()->base_type() != reflection::Vector) {
    TC3_LOG(ERROR) << "Field is not of type Vector.";
    return nullptr;
  }

  // If the repeated field was already set, return its instance.
  const auto entry = repeated_fields_.find(field);
  if (entry != repeated_fields_.end()) {
    return entry->second.get();
  }

  // Otherwise, create a new instance and store it.
  std::unique_ptr<RepeatedField> repeated_field;
  if (!CreateRepeatedField(schema_, field->type(), &repeated_field)) {
    TC3_LOG(ERROR) << "Could not create repeated field.";
    return nullptr;
  }
  const auto it = repeated_fields_.insert(
      /*hint=*/entry, std::make_pair(field, std::move(repeated_field)));
  return it->second.get();
}

// Serializes this table into `builder` and returns the offset of the
// finished table. Sub-tables, strings and repeated fields are written before
// the table itself is started; the table then stores the scalar values
// inline and references the previously written data by offset.
flatbuffers::uoffset_t ReflectiveFlatbuffer::Serialize(
    flatbuffers::FlatBufferBuilder* builder) const {
  // Build all children before we can start with this table.
  std::vector<
      std::pair</* field vtable offset */ int,
                /* field data offset in buffer */ flatbuffers::uoffset_t>>
      offsets;
  offsets.reserve(children_.size() + repeated_fields_.size());
  for (const auto& it : children_) {
    offsets.push_back({it.first->offset(), it.second->Serialize(builder)});
  }

  // Create strings.
  for (const auto& it : fields_) {
    if (it.second.HasString()) {
      offsets.push_back({it.first->offset(),
                         builder->CreateString(it.second.StringValue()).o});
    }
  }

  // Build the repeated fields.
  for (const auto& it : repeated_fields_) {
    offsets.push_back({it.first->offset(), it.second->Serialize(builder)});
  }

  // Build the table now.
  const flatbuffers::uoffset_t table_start = builder->StartTable();

  // Add scalar fields.
  for (const auto& it : fields_) {
    switch (it.second.GetType()) {
      case Variant::TYPE_BOOL_VALUE:
        builder->AddElement<uint8_t>(
            it.first->offset(), static_cast<uint8_t>(it.second.BoolValue()),
            static_cast<uint8_t>(it.first->default_integer()));
        continue;
      case Variant::TYPE_INT_VALUE:
        builder->AddElement<int32>(
            it.first->offset(), it.second.IntValue(),
            static_cast<int32>(it.first->default_integer()));
        continue;
      case Variant::TYPE_INT64_VALUE:
        builder->AddElement<int64>(it.first->offset(), it.second.Int64Value(),
                                   it.first->default_integer());
        continue;
      case Variant::TYPE_FLOAT_VALUE:
        builder->AddElement<float>(
            it.first->offset(), it.second.FloatValue(),
            static_cast<float>(it.first->default_real()));
        continue;
      case Variant::TYPE_DOUBLE_VALUE:
        builder->AddElement<double>(it.first->offset(), it.second.DoubleValue(),
                                    it.first->default_real());
        continue;
      default:
        // Non-scalar values (e.g. strings) were written above and are added
        // by offset below.
        continue;
    }
  }

  // Add strings, subtables and repeated fields.
  for (const auto& it : offsets) {
    builder->AddOffset(it.first, flatbuffers::Offset<void>(it.second));
  }

  return builder->EndTable(table_start);
}

// Serializes this table as the root of a fresh buffer and returns the
// finished buffer's bytes.
std::string ReflectiveFlatbuffer::Serialize() const {
  flatbuffers::FlatBufferBuilder builder;
  builder.Finish(flatbuffers::Offset<void>(Serialize(&builder)));
  return std::string(reinterpret_cast<const char*>(builder.GetBufferPointer()),
                     builder.GetSize());
}

// Copies every field that is present in `from` into this table; sub-tables
// are merged recursively. `from` is read using this table's type.
//
// Returns true on success (including when this type has no fields). Returns
// false if `from` is null, if a present field has an unsupported type, or if
// merging a sub-table fails.
bool ReflectiveFlatbuffer::MergeFrom(const flatbuffers::Table* from) {
  // No fields to set.
  if (type_->fields() == nullptr) {
    return true;
  }

  // Every path below reads from `from`, so a null table cannot be merged.
  if (from == nullptr) {
    TC3_LOG(ERROR) << "Cannot merge from a null table.";
    return false;
  }

  for (const reflection::Field* field : *type_->fields()) {
    // Skip fields that are not explicitly set.
    if (!from->CheckField(field->offset())) {
      continue;
    }
    const reflection::BaseType type = field->type()->base_type();
    switch (type) {
      case reflection::Bool:
        Set<bool>(field, from->GetField<uint8_t>(field->offset(),
                                                 field->default_integer()));
        break;
      case reflection::Int:
        Set<int32>(field, from->GetField<int32>(field->offset(),
                                                field->default_integer()));
        break;
      case reflection::Long:
        Set<int64>(field, from->GetField<int64>(field->offset(),
                                                field->default_integer()));
        break;
      case reflection::Float:
        Set<float>(field, from->GetField<float>(field->offset(),
                                                field->default_real()));
        break;
      case reflection::Double:
        Set<double>(field, from->GetField<double>(field->offset(),
                                                  field->default_real()));
        break;
      case reflection::String:
        Set<std::string>(
            field, from->GetPointer<const flatbuffers::String*>(field->offset())
                       ->str());
        break;
      case reflection::Obj:
        if (!Mutable(field)->MergeFrom(
                from->GetPointer<const flatbuffers::Table* const>(
                    field->offset()))) {
          return false;
        }
        break;
      default:
        TC3_LOG(ERROR) << "Unsupported type: " << type;
        return false;
    }
  }
  return true;
}

// Interprets `from` as a serialized flatbuffer and merges its root table
// into this one. The buffer contents are not verified here.
// Returns false if `from` has no data or the merge fails.
bool ReflectiveFlatbuffer::MergeFromSerializedFlatbuffer(StringPiece from) {
  if (from.data() == nullptr) {
    TC3_LOG(ERROR) << "Cannot merge from a null buffer.";
    return false;
  }
  return MergeFrom(flatbuffers::GetAnyRoot(
      reinterpret_cast<const unsigned char*>(from.data())));
}

// Writes the fields set on this table, and recursively on its sub-tables,
// into `*result`. Each key is `key_prefix` followed by the field name; the
// fields of a sub-table get the sub-table's field name plus `key_separator`
// appended to the prefix. Existing entries with the same key are overwritten.
void ReflectiveFlatbuffer::AsFlatMap(
    const std::string& key_separator, const std::string& key_prefix,
    std::map<std::string, Variant>* result) const {
  // Add direct fields.
  for (const auto& it : fields_) {
    (*result)[key_prefix + it.first->name()->str()] = it.second;
  }

  // Add nested messages.
  for (const auto& it : children_) {
    it.second->AsFlatMap(key_separator,
                         key_prefix + it.first->name()->str() + key_separator,
                         result);
  }
}

}  // namespace libtextclassifier3