/*
* Copyright (C) 2017 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/protozero/protoc_plugin/protozero_generator.h"
#include <map>
#include <memory>
#include <set>
#include <string>
#include "google/protobuf/descriptor.h"
#include "google/protobuf/io/printer.h"
#include "google/protobuf/io/zero_copy_stream.h"
#include "google/protobuf/stubs/strutil.h"
namespace protozero {
using google::protobuf::Descriptor; // Message descriptor.
using google::protobuf::EnumDescriptor;
using google::protobuf::EnumValueDescriptor;
using google::protobuf::FieldDescriptor;
using google::protobuf::FileDescriptor;
using google::protobuf::compiler::GeneratorContext;
using google::protobuf::io::Printer;
using google::protobuf::io::ZeroCopyOutputStream;
using google::protobuf::Split;
using google::protobuf::StripPrefixString;
using google::protobuf::StripString;
using google::protobuf::StripSuffixString;
using google::protobuf::UpperString;
namespace {
// Keep this value in sync with ProtoDecoder::kMaxDecoderFieldId. If they go out
// of sync pbzero.h files will stop compiling, hitting the at() static_assert.
// Not worth an extra dependency.
constexpr int kMaxDecoderFieldId = 999;
void Assert(bool condition) {
if (!condition)
__builtin_trap();
}
struct FileDescriptorComp {
bool operator()(const FileDescriptor* lhs, const FileDescriptor* rhs) const {
int comp = lhs->name().compare(rhs->name());
Assert(comp != 0 || lhs == rhs);
return comp < 0;
}
};
struct DescriptorComp {
bool operator()(const Descriptor* lhs, const Descriptor* rhs) const {
int comp = lhs->full_name().compare(rhs->full_name());
Assert(comp != 0 || lhs == rhs);
return comp < 0;
}
};
struct EnumDescriptorComp {
bool operator()(const EnumDescriptor* lhs, const EnumDescriptor* rhs) const {
int comp = lhs->full_name().compare(rhs->full_name());
Assert(comp != 0 || lhs == rhs);
return comp < 0;
}
};
inline std::string ProtoStubName(const FileDescriptor* proto) {
return StripSuffixString(proto->name(), ".proto") + ".pbzero";
}
class GeneratorJob {
public:
GeneratorJob(const FileDescriptor* file, Printer* stub_h_printer)
: source_(file), stub_h_(stub_h_printer) {}
bool GenerateStubs() {
Preprocess();
GeneratePrologue();
for (const EnumDescriptor* enumeration : enums_)
GenerateEnumDescriptor(enumeration);
for (const Descriptor* message : messages_)
GenerateMessageDescriptor(message);
GenerateEpilogue();
return error_.empty();
}
void SetOption(const std::string& name, const std::string& value) {
if (name == "wrapper_namespace") {
wrapper_namespace_ = value;
} else {
Abort(std::string() + "Unknown plugin option '" + name + "'.");
}
}
// If generator fails to produce stubs for a particular proto definitions
// it finishes with undefined output and writes the first error occured.
const std::string& GetFirstError() const { return error_; }
private:
// Only the first error will be recorded.
void Abort(const std::string& reason) {
if (error_.empty())
error_ = reason;
}
// Get full name (including outer descriptors) of proto descriptor.
template <class T>
inline std::string GetDescriptorName(const T* descriptor) {
if (!package_.empty()) {
return StripPrefixString(descriptor->full_name(), package_ + ".");
} else {
return descriptor->full_name();
}
}
// Get C++ class name corresponding to proto descriptor.
// Nested names are splitted by underscores. Underscores in type names aren't
// prohibited but not recommended in order to avoid name collisions.
template <class T>
inline std::string GetCppClassName(const T* descriptor, bool full = false) {
std::string name = GetDescriptorName(descriptor);
StripString(&name, ".", '_');
if (full)
name = full_namespace_prefix_ + name;
return name;
}
inline std::string GetFieldNumberConstant(const FieldDescriptor* field) {
std::string name = field->camelcase_name();
if (!name.empty()) {
name.at(0) = static_cast<char>(toupper(name.at(0)));
name = "k" + name + "FieldNumber";
} else {
// Protoc allows fields like 'bool _ = 1'.
Abort("Empty field name in camel case notation.");
}
return name;
}
// Small enums can be written faster without involving VarInt encoder.
inline bool IsTinyEnumField(const FieldDescriptor* field) {
if (field->type() != FieldDescriptor::TYPE_ENUM)
return false;
const EnumDescriptor* enumeration = field->enum_type();
for (int i = 0; i < enumeration->value_count(); ++i) {
int32_t value = enumeration->value(i)->number();
if (value < 0 || value > 0x7F)
return false;
}
return true;
}
void CollectDescriptors() {
// Collect message descriptors in DFS order.
std::vector<const Descriptor*> stack;
for (int i = 0; i < source_->message_type_count(); ++i)
stack.push_back(source_->message_type(i));
while (!stack.empty()) {
const Descriptor* message = stack.back();
stack.pop_back();
messages_.push_back(message);
for (int i = 0; i < message->nested_type_count(); ++i) {
stack.push_back(message->nested_type(i));
}
}
// Collect enums.
for (int i = 0; i < source_->enum_type_count(); ++i)
enums_.push_back(source_->enum_type(i));
for (const Descriptor* message : messages_) {
for (int i = 0; i < message->enum_type_count(); ++i) {
enums_.push_back(message->enum_type(i));
}
}
}
void CollectDependencies() {
// Public import basically means that callers only need to import this
// proto in order to use the stuff publicly imported by this proto.
for (int i = 0; i < source_->public_dependency_count(); ++i)
public_imports_.insert(source_->public_dependency(i));
if (source_->weak_dependency_count() > 0)
Abort("Weak imports are not supported.");
// Sanity check. Collect public imports (of collected imports) in DFS order.
// Visibilty for current proto:
// - all imports listed in current proto,
// - public imports of everything imported (recursive).
std::vector<const FileDescriptor*> stack;
for (int i = 0; i < source_->dependency_count(); ++i) {
const FileDescriptor* import = source_->dependency(i);
stack.push_back(import);
if (public_imports_.count(import) == 0) {
private_imports_.insert(import);
}
}
while (!stack.empty()) {
const FileDescriptor* import = stack.back();
stack.pop_back();
// Having imports under different packages leads to unnecessary
// complexity with namespaces.
if (import->package() != package_)
Abort("Imported proto must be in the same package.");
for (int i = 0; i < import->public_dependency_count(); ++i) {
stack.push_back(import->public_dependency(i));
}
}
// Collect descriptors of messages and enums used in current proto.
// It will be used to generate necessary forward declarations and performed
// sanity check guarantees that everything lays in the same namespace.
for (const Descriptor* message : messages_) {
for (int i = 0; i < message->field_count(); ++i) {
const FieldDescriptor* field = message->field(i);
if (field->type() == FieldDescriptor::TYPE_MESSAGE) {
if (public_imports_.count(field->message_type()->file()) == 0) {
// Avoid multiple forward declarations since
// public imports have been already included.
referenced_messages_.insert(field->message_type());
}
} else if (field->type() == FieldDescriptor::TYPE_ENUM) {
if (public_imports_.count(field->enum_type()->file()) == 0) {
referenced_enums_.insert(field->enum_type());
}
}
}
}
}
void Preprocess() {
// Package name maps to a series of namespaces.
package_ = source_->package();
namespaces_ = Split(package_, ".");
if (!wrapper_namespace_.empty())
namespaces_.push_back(wrapper_namespace_);
full_namespace_prefix_ = "::";
for (const std::string& ns : namespaces_)
full_namespace_prefix_ += ns + "::";
CollectDescriptors();
CollectDependencies();
}
// Print top header, namespaces and forward declarations.
void GeneratePrologue() {
std::string greeting =
"// Autogenerated by the ProtoZero compiler plugin. DO NOT EDIT.\n";
std::string guard = package_ + "_" + source_->name() + "_H_";
UpperString(&guard);
StripString(&guard, ".-/\\", '_');
stub_h_->Print(
"$greeting$\n"
"#ifndef $guard$\n"
"#define $guard$\n\n"
"#include <stddef.h>\n"
"#include <stdint.h>\n\n"
"#include \"perfetto/base/export.h\"\n"
"#include \"perfetto/protozero/proto_decoder.h\"\n"
"#include \"perfetto/protozero/message.h\"\n",
"greeting", greeting, "guard", guard);
// Print includes for public imports.
for (const FileDescriptor* dependency : public_imports_) {
// Dependency name could contain slashes but importing from upper-level
// directories is not possible anyway since build system processes each
// proto file individually. Hence proto lookup path is always equal to the
// directory where particular proto file is located and protoc does not
// allow reference to upper directory (aka ..) in import path.
//
// Laconically said:
// - source_->name() may never have slashes,
// - dependency->name() may have slashes but always refers to inner path.
stub_h_->Print("#include \"$name$.h\"\n", "name",
ProtoStubName(dependency));
}
stub_h_->Print("\n");
// Print namespaces.
for (const std::string& ns : namespaces_) {
stub_h_->Print("namespace $ns$ {\n", "ns", ns);
}
stub_h_->Print("\n");
// Print forward declarations.
for (const Descriptor* message : referenced_messages_) {
stub_h_->Print("class $class$;\n", "class", GetCppClassName(message));
}
for (const EnumDescriptor* enumeration : referenced_enums_) {
stub_h_->Print("enum $class$ : int32_t;\n", "class",
GetCppClassName(enumeration));
}
stub_h_->Print("\n");
}
void GenerateEnumDescriptor(const EnumDescriptor* enumeration) {
stub_h_->Print("enum $class$ : int32_t {\n", "class",
GetCppClassName(enumeration));
stub_h_->Indent();
std::string value_name_prefix;
if (enumeration->containing_type() != nullptr)
value_name_prefix = GetCppClassName(enumeration) + "_";
for (int i = 0; i < enumeration->value_count(); ++i) {
const EnumValueDescriptor* value = enumeration->value(i);
stub_h_->Print("$name$ = $number$,\n", "name",
value_name_prefix + value->name(), "number",
std::to_string(value->number()));
}
stub_h_->Outdent();
stub_h_->Print("};\n\n");
}
void GenerateSimpleFieldDescriptor(const FieldDescriptor* field) {
std::map<std::string, std::string> setter;
setter["id"] = std::to_string(field->number());
setter["name"] = field->name();
setter["action"] = field->is_repeated() ? "add" : "set";
std::string appender;
std::string cpp_type;
switch (field->type()) {
case FieldDescriptor::TYPE_BOOL: {
appender = "AppendTinyVarInt";
cpp_type = "bool";
break;
}
case FieldDescriptor::TYPE_INT32: {
appender = "AppendVarInt";
cpp_type = "int32_t";
break;
}
case FieldDescriptor::TYPE_INT64: {
appender = "AppendVarInt";
cpp_type = "int64_t";
break;
}
case FieldDescriptor::TYPE_UINT32: {
appender = "AppendVarInt";
cpp_type = "uint32_t";
break;
}
case FieldDescriptor::TYPE_UINT64: {
appender = "AppendVarInt";
cpp_type = "uint64_t";
break;
}
case FieldDescriptor::TYPE_SINT32: {
appender = "AppendSignedVarInt";
cpp_type = "int32_t";
break;
}
case FieldDescriptor::TYPE_SINT64: {
appender = "AppendSignedVarInt";
cpp_type = "int64_t";
break;
}
case FieldDescriptor::TYPE_FIXED32: {
appender = "AppendFixed";
cpp_type = "uint32_t";
break;
}
case FieldDescriptor::TYPE_FIXED64: {
appender = "AppendFixed";
cpp_type = "uint64_t";
break;
}
case FieldDescriptor::TYPE_SFIXED32: {
appender = "AppendFixed";
cpp_type = "int32_t";
break;
}
case FieldDescriptor::TYPE_SFIXED64: {
appender = "AppendFixed";
cpp_type = "int64_t";
break;
}
case FieldDescriptor::TYPE_FLOAT: {
appender = "AppendFixed";
cpp_type = "float";
break;
}
case FieldDescriptor::TYPE_DOUBLE: {
appender = "AppendFixed";
cpp_type = "double";
break;
}
case FieldDescriptor::TYPE_ENUM: {
appender = IsTinyEnumField(field) ? "AppendTinyVarInt" : "AppendVarInt";
cpp_type = GetCppClassName(field->enum_type(), true);
break;
}
case FieldDescriptor::TYPE_STRING: {
appender = "AppendString";
cpp_type = "const char*";
break;
}
case FieldDescriptor::TYPE_BYTES: {
stub_h_->Print(
setter,
"void $action$_$name$(const uint8_t* data, size_t size) {\n"
" AppendBytes($id$, data, size);\n"
"}\n");
return;
}
case FieldDescriptor::TYPE_GROUP:
case FieldDescriptor::TYPE_MESSAGE: {
Abort("Unsupported field type.");
return;
}
}
setter["appender"] = appender;
setter["cpp_type"] = cpp_type;
stub_h_->Print(setter,
"void $action$_$name$($cpp_type$ value) {\n"
" $appender$($id$, value);\n"
"}\n");
// For strings also generate a variant for non-null terminated strings.
if (field->type() == FieldDescriptor::TYPE_STRING) {
stub_h_->Print(setter,
"// Doesn't check for null terminator.\n"
"// Expects |value| to be at least |size| long.\n"
"void $action$_$name$($cpp_type$ value, size_t size) {\n"
" AppendBytes($id$, value, size);\n"
"}\n");
}
}
void GenerateNestedMessageFieldDescriptor(const FieldDescriptor* field) {
std::string action = field->is_repeated() ? "add" : "set";
std::string inner_class = GetCppClassName(field->message_type());
stub_h_->Print(
"template <typename T = $inner_class$> T* $action$_$name$() {\n"
" return BeginNestedMessage<T>($id$);\n"
"}\n\n",
"id", std::to_string(field->number()), "name", field->name(), "action",
action, "inner_class", inner_class);
}
void GenerateDecoder(const Descriptor* message) {
int max_field_id = 0;
bool has_repeated_fields = false;
for (int i = 0; i < message->field_count(); ++i) {
const FieldDescriptor* field = message->field(i);
if (field->number() > kMaxDecoderFieldId)
continue;
max_field_id = std::max(max_field_id, field->number());
if (field->is_repeated())
has_repeated_fields = true;
}
stub_h_->Print(
"class Decoder : public "
"::protozero::TypedProtoDecoder</*MAX_FIELD_ID=*/$max$, "
"/*HAS_REPEATED_FIELDS=*/$rep$> {\n",
"max", std::to_string(max_field_id), "rep",
has_repeated_fields ? "true" : "false");
stub_h_->Print(" public:\n");
stub_h_->Indent();
stub_h_->Print(
"Decoder(const uint8_t* data, size_t len) "
": TypedProtoDecoder(data, len) {}\n");
for (int i = 0; i < message->field_count(); ++i) {
const FieldDescriptor* field = message->field(i);
if (field->is_packed()) {
Abort("Packed repeated fields are not supported.");
return;
}
if (field->number() > max_field_id) {
stub_h_->Print("// field $name$ omitted because its id is too high\n",
"name", field->name());
continue;
}
std::string getter;
std::string cpp_type;
switch (field->type()) {
case FieldDescriptor::TYPE_BOOL:
getter = "as_bool";
cpp_type = "bool";
break;
case FieldDescriptor::TYPE_SFIXED32:
case FieldDescriptor::TYPE_SINT32:
case FieldDescriptor::TYPE_INT32:
getter = "as_int32";
cpp_type = "int32_t";
break;
case FieldDescriptor::TYPE_SFIXED64:
case FieldDescriptor::TYPE_SINT64:
case FieldDescriptor::TYPE_INT64:
getter = "as_int64";
cpp_type = "int64_t";
break;
case FieldDescriptor::TYPE_FIXED32:
case FieldDescriptor::TYPE_UINT32:
getter = "as_uint32";
cpp_type = "uint32_t";
break;
case FieldDescriptor::TYPE_FIXED64:
case FieldDescriptor::TYPE_UINT64:
getter = "as_uint64";
cpp_type = "uint64_t";
break;
case FieldDescriptor::TYPE_FLOAT:
getter = "as_float";
cpp_type = "float";
break;
case FieldDescriptor::TYPE_DOUBLE:
getter = "as_double";
cpp_type = "double";
break;
case FieldDescriptor::TYPE_ENUM:
getter = "as_int32";
cpp_type = "int32_t";
break;
case FieldDescriptor::TYPE_STRING:
getter = "as_string";
cpp_type = "::protozero::ConstChars";
break;
case FieldDescriptor::TYPE_MESSAGE:
case FieldDescriptor::TYPE_BYTES:
getter = "as_bytes";
cpp_type = "::protozero::ConstBytes";
break;
case FieldDescriptor::TYPE_GROUP:
continue;
}
stub_h_->Print("bool has_$name$() const { return at<$id$>().valid(); }\n",
"name", field->name(), "id",
std::to_string(field->number()));
if (field->is_repeated()) {
stub_h_->Print(
"::protozero::RepeatedFieldIterator $name$() const { return "
"GetRepeated($id$); }\n",
"name", field->name(), "id", std::to_string(field->number()));
} else {
stub_h_->Print(
"$cpp_type$ $name$() const { return at<$id$>().$getter$(); }\n",
"name", field->name(), "id", std::to_string(field->number()),
"cpp_type", cpp_type, "getter", getter);
}
}
stub_h_->Outdent();
stub_h_->Print("};\n");
}
void GenerateConstantsForMessageFields(const Descriptor* message) {
const bool has_fields = (message->field_count() > 0);
// Field number constants.
if (has_fields) {
stub_h_->Print("enum : int32_t {\n");
stub_h_->Indent();
for (int i = 0; i < message->field_count(); ++i) {
const FieldDescriptor* field = message->field(i);
stub_h_->Print("$name$ = $id$,\n", "name",
GetFieldNumberConstant(field), "id",
std::to_string(field->number()));
}
stub_h_->Outdent();
stub_h_->Print("};\n");
}
}
void GenerateMessageDescriptor(const Descriptor* message) {
stub_h_->Print(
"class PERFETTO_EXPORT $name$ : public ::protozero::Message {\n"
" public:\n",
"name", GetCppClassName(message));
stub_h_->Indent();
GenerateConstantsForMessageFields(message);
GenerateDecoder(message);
// Using statements for nested messages.
for (int i = 0; i < message->nested_type_count(); ++i) {
const Descriptor* nested_message = message->nested_type(i);
stub_h_->Print("using $local_name$ = $global_name$;\n", "local_name",
nested_message->name(), "global_name",
GetCppClassName(nested_message, true));
}
// Using statements for nested enums.
for (int i = 0; i < message->enum_type_count(); ++i) {
const EnumDescriptor* nested_enum = message->enum_type(i);
stub_h_->Print("using $local_name$ = $global_name$;\n", "local_name",
nested_enum->name(), "global_name",
GetCppClassName(nested_enum, true));
}
// Values of nested enums.
for (int i = 0; i < message->enum_type_count(); ++i) {
const EnumDescriptor* nested_enum = message->enum_type(i);
std::string value_name_prefix = GetCppClassName(nested_enum) + "_";
for (int j = 0; j < nested_enum->value_count(); ++j) {
const EnumValueDescriptor* value = nested_enum->value(j);
stub_h_->Print("static const $class$ $name$ = $full_name$;\n", "class",
nested_enum->name(), "name", value->name(), "full_name",
value_name_prefix + value->name());
}
}
// Field descriptors.
for (int i = 0; i < message->field_count(); ++i) {
const FieldDescriptor* field = message->field(i);
if (field->is_packed()) {
Abort("Packed repeated fields are not supported.");
return;
}
if (field->type() != FieldDescriptor::TYPE_MESSAGE) {
GenerateSimpleFieldDescriptor(field);
} else {
GenerateNestedMessageFieldDescriptor(field);
}
}
stub_h_->Outdent();
stub_h_->Print("};\n\n");
}
void GenerateEpilogue() {
for (unsigned i = 0; i < namespaces_.size(); ++i) {
stub_h_->Print("} // Namespace.\n");
}
stub_h_->Print("#endif // Include guard.\n");
}
const FileDescriptor* const source_;
Printer* const stub_h_;
std::string error_;
std::string package_;
std::string wrapper_namespace_;
std::vector<std::string> namespaces_;
std::string full_namespace_prefix_;
std::vector<const Descriptor*> messages_;
std::vector<const EnumDescriptor*> enums_;
// The custom *Comp comparators are to ensure determinism of the generator.
std::set<const FileDescriptor*, FileDescriptorComp> public_imports_;
std::set<const FileDescriptor*, FileDescriptorComp> private_imports_;
std::set<const Descriptor*, DescriptorComp> referenced_messages_;
std::set<const EnumDescriptor*, EnumDescriptorComp> referenced_enums_;
};
} // namespace
ProtoZeroGenerator::ProtoZeroGenerator() {}
ProtoZeroGenerator::~ProtoZeroGenerator() {}
bool ProtoZeroGenerator::Generate(const FileDescriptor* file,
const std::string& options,
GeneratorContext* context,
std::string* error) const {
const std::unique_ptr<ZeroCopyOutputStream> stub_h_file_stream(
context->Open(ProtoStubName(file) + ".h"));
const std::unique_ptr<ZeroCopyOutputStream> stub_cc_file_stream(
context->Open(ProtoStubName(file) + ".cc"));
// Variables are delimited by $.
Printer stub_h_printer(stub_h_file_stream.get(), '$');
GeneratorJob job(file, &stub_h_printer);
Printer stub_cc_printer(stub_cc_file_stream.get(), '$');
stub_cc_printer.Print("// Intentionally empty\n");
// Parse additional options.
for (const std::string& option : Split(options, ",")) {
std::vector<std::string> option_pair = Split(option, "=");
job.SetOption(option_pair[0], option_pair[1]);
}
if (!job.GenerateStubs()) {
*error = job.GetFirstError();
return false;
}
return true;
}
} // namespace protozero