Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
183 changes: 149 additions & 34 deletions patch/patcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,13 @@ import (
"golang.org/x/tools/go/ast/astutil"
"google.golang.org/protobuf/cmd/protoc-gen-go/internal_gengo"
"google.golang.org/protobuf/compiler/protogen"
"google.golang.org/protobuf/types/descriptorpb"
"google.golang.org/protobuf/types/pluginpb"

"github.com/golang/protobuf/proto"

"github.com/alta/protopatch/lint"
"github.com/alta/protopatch/patch/gopb"
"github.com/alta/protopatch/patch/ident"
)

Expand All @@ -33,56 +37,166 @@ import (
// - (go.enum).name overrides the name of an enum type.
// - (go.value).name overrides the name of an enum value.
type Patcher struct {
gen *protogen.Plugin
fset *token.FileSet
filesByName map[string]*ast.File
info *types.Info
packages []*Package
packagesByPath map[string]*Package
packagesByName map[string]*Package
renames map[protogen.GoIdent]string
typeRenames map[protogen.GoIdent]string
valueRenames map[protogen.GoIdent]string
fieldRenames map[protogen.GoIdent]string
methodRenames map[protogen.GoIdent]string
objectRenames map[types.Object]string
tags map[protogen.GoIdent]string
fieldTags map[types.Object]string
embeds map[protogen.GoIdent]string
fieldEmbeds map[types.Object]string
types map[protogen.GoIdent]string
fieldTypes map[types.Object]string
gen *protogen.Plugin
fset *token.FileSet
filesByName map[string]*ast.File
info *types.Info
packages []*Package
packagesByPath map[string]*Package
packagesByName map[string]*Package
renames map[protogen.GoIdent]string
typeRenames map[protogen.GoIdent]string
valueRenames map[protogen.GoIdent]string
fieldRenames map[protogen.GoIdent]string
methodRenames map[protogen.GoIdent]string
objectRenames map[types.Object]string
tags map[protogen.GoIdent]string
fieldTags map[types.Object]string
embeds map[protogen.GoIdent]string
fieldEmbeds map[types.Object]string
types map[protogen.GoIdent]string
fieldTypes map[types.Object]string
processedMessages map[protogen.GoIdent]bool
}

// NewPatcher returns an initialized Patcher for gen.
func NewPatcher(gen *protogen.Plugin) (*Patcher, error) {
p := &Patcher{
gen: gen,
packagesByPath: make(map[string]*Package),
packagesByName: make(map[string]*Package),
renames: make(map[protogen.GoIdent]string),
typeRenames: make(map[protogen.GoIdent]string),
valueRenames: make(map[protogen.GoIdent]string),
fieldRenames: make(map[protogen.GoIdent]string),
methodRenames: make(map[protogen.GoIdent]string),
objectRenames: make(map[types.Object]string),
tags: make(map[protogen.GoIdent]string),
fieldTags: make(map[types.Object]string),
embeds: make(map[protogen.GoIdent]string),
fieldEmbeds: make(map[types.Object]string),
types: make(map[protogen.GoIdent]string),
fieldTypes: make(map[types.Object]string),
gen: gen,
packagesByPath: make(map[string]*Package),
packagesByName: make(map[string]*Package),
renames: make(map[protogen.GoIdent]string),
typeRenames: make(map[protogen.GoIdent]string),
valueRenames: make(map[protogen.GoIdent]string),
fieldRenames: make(map[protogen.GoIdent]string),
methodRenames: make(map[protogen.GoIdent]string),
objectRenames: make(map[types.Object]string),
tags: make(map[protogen.GoIdent]string),
fieldTags: make(map[types.Object]string),
embeds: make(map[protogen.GoIdent]string),
fieldEmbeds: make(map[types.Object]string),
types: make(map[protogen.GoIdent]string),
fieldTypes: make(map[types.Object]string),
processedMessages: make(map[protogen.GoIdent]bool),
}
return p, p.scan()
}

func getExtensionDesc(pb proto.Message, extname string) (*proto.ExtensionDesc, error) {
desc := proto.RegisteredExtensions(pb)
for _, d := range desc {
if d.Name == extname {
return d, nil
}
}
return nil, fmt.Errorf("ExtensionDesc not found")
}

func getExtension(pb proto.Message, extname string) (interface{}, error) {
d, err := getExtensionDesc(pb, extname)
if err != nil {
return nil, err
}
e, err := proto.GetExtension(pb, d)
if err != nil {
return nil, err
}
return e, err
}

func (p *Patcher) scan() error {
for _, f := range p.gen.Files {
p.scanFile(f)
}
for _, f := range p.gen.Request.ProtoFile {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you put this in a new scanRequest method?

found := false
mident := protogen.GoIdent{GoName: "", GoImportPath: ""}
fident := protogen.GoIdent{GoName: "", GoImportPath: ""}
for _, genFile := range p.gen.Files {
if *f.Name == genFile.Desc.Path() {
found = true
mident = protogen.GoIdent{GoName: "", GoImportPath: genFile.GoImportPath}
fident = protogen.GoIdent{GoName: "", GoImportPath: genFile.GoImportPath}
break
}
}
if !found {
panic("Not found")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this return an error instead?

}
for _, m := range f.MessageType {
mident.GoName = *m.Name
if _, ok := p.processedMessages[mident]; ok {
continue
}
nmident := protogen.GoIdent{GoName: "", GoImportPath: mident.GoImportPath}
nfident := protogen.GoIdent{GoName: "", GoImportPath: mident.GoImportPath}
for _, nestedMsgType := range m.NestedType {
nmident.GoName = mident.GoName + "_" + *nestedMsgType.Name
for _, msgfield := range nestedMsgType.Field {
nfident.GoName = *msgfield.Name
p.scanProtoField(nmident, nfident, msgfield)
}
}

for _, msgfield := range m.Field {
fident.GoName = *msgfield.Name
p.scanProtoField(mident, fident, msgfield)
}
}
}

return nil
}

func (p *Patcher) scanProtoField(mident protogen.GoIdent, fident protogen.GoIdent, f *descriptorpb.FieldDescriptorProto) {
//m := f.Parent
//o := f.Oneof

if f.TypeName == nil {
log.Printf("Typename not set")
return
}
fi, err := getExtension(f.GetOptions(), "go.field")
if err != nil {
log.Printf("go.field extension not found", err)
return
}
opts := fi.(*gopb.Options)

log.Printf("Parent Message %v (%v), opts %v", *f.Name, *f.TypeName, opts)
// Embed field ?
embed := false
newName := ""
if opts.GetEmbed() {
switch {
default:
embed = true
log.Printf("Embed Set %v ", *f.Name, *f.TypeName)
splitStrings := strings.Split((*f.TypeName)[1:], ".")
newName = splitStrings[len(splitStrings)-1]
}
}
if newName != "" {
if false {
p.RenameType(fident, p.nameFor(mident)+"_"+newName) // Oneof wrapper struct
p.RenameField(ident.WithChild(fident, fident.GoName), newName, false) // Oneof wrapper field (not embeddable)
} else {
p.RenameField(ident.WithChild(mident, fident.GoName), newName, embed) // Field
childID := ident.WithChild(mident, fident.GoName)
log.Printf("child %v parent %v", childID, mident.GoName)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are these debugging logs?

}
p.RenameMethod(ident.WithChild(mident, "Get"+fident.GoName), "Get"+newName) // Getter
} else {
p.RenameField(ident.WithChild(mident, fident.GoName), newName, embed) // Field
}

tags := opts.GetTags()
if tags != "" {
log.Printf("Tags Set identifier %v %v %v", ident.WithChild(mident, fident.GoName), *f.Name, *f.TypeName, tags)
p.Tag(ident.WithChild(mident, fident.GoName), tags) // Field tags
}
}

func (p *Patcher) scanFile(f *protogen.File) {
log.Printf("\nScan proto:\t%s", f.Desc.Path())

Expand Down Expand Up @@ -190,6 +304,7 @@ func (p *Patcher) scanMessage(m *protogen.Message, parent *protogen.Message) {
opts := messageOptions(m)
lints := fileLintOptions(m.Desc)

p.processedMessages[m.GoIdent] = true
// Rename message?
newName := opts.GetName()
if newName == "" && parent != nil && p.isRenamed(parent.GoIdent) {
Expand Down