// Protocol Buffers for Go with Gadgets
//
// Copyright (c) 2013, The GoGo Authors. All rights reserved.
// http://github.com/gogo/protobuf
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
//     * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
//     * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

package compare

import (
	"github.com/gogo/protobuf/gogoproto"
	"github.com/gogo/protobuf/proto"
	descriptor "github.com/gogo/protobuf/protoc-gen-gogo/descriptor"
	"github.com/gogo/protobuf/protoc-gen-gogo/generator"
	"github.com/gogo/protobuf/vanity"
)

// plugin is the generator plugin that emits Compare methods. It embeds the
// generator it was attached to in Init and the import tracker created in
// Generate.
type plugin struct {
	*generator.Generator
	generator.PluginImports

	// Import handles for packages the generated code may reference; each one
	// is (re)assigned at the start of Generate.
	fmtPkg, bytesPkg, sortkeysPkg, protoPkg generator.Single
}

// NewPlugin returns a new compare plugin with all fields at their zero value.
// The generator is attached later through Init.
func NewPlugin() *plugin {
	return new(plugin)
}

// Name returns the name of this plugin, "compare".
func (p *plugin) Name() string {
	const name = "compare"
	return name
}

// Init stores the generator that this plugin writes its output through.
func (p *plugin) Init(gen *generator.Generator) {
	p.Generator = gen
}

// Generate resets the plugin's import tracking for file and then emits
// Compare methods for every message of file for which gogoproto.HasCompare
// reports true. Messages flagged as map entries are skipped.
func (p *plugin) Generate(file *generator.FileDescriptor) {
	p.PluginImports = generator.NewPluginImports(p.Generator)
	p.fmtPkg = p.NewImport("fmt")
	p.bytesPkg = p.NewImport("bytes")
	p.sortkeysPkg = p.NewImport("github.com/gogo/protobuf/sortkeys")
	p.protoPkg = p.NewImport("github.com/gogo/protobuf/proto")

	for _, msg := range file.Messages() {
		desc := msg.DescriptorProto
		if desc.GetOptions().GetMapEntry() || !gogoproto.HasCompare(file.FileDescriptorProto, desc) {
			continue
		}
		p.generateMessage(file, msg)
	}
}

// generateNullableField emits the comparison of a pointer-typed field whose
// pointee supports the != and < operators.
//
// In the generated code: when both pointers are set the pointed-to values are
// compared (-1 if this is smaller, 1 if it differs otherwise); when only this
// is set the result is 1; when only that1 is set the result is -1. If nothing
// differs the generated code falls through without returning.
func (p *plugin) generateNullableField(fieldname string) {
	lhs := `this.` + fieldname
	rhs := `that1.` + fieldname

	// Both set: order by the dereferenced values.
	p.P(`if `, lhs, ` != nil && `, rhs, ` != nil {`)
	p.In()
	p.P(`if *`, lhs, ` != *`, rhs, `{`)
	p.In()
	p.P(`if *`, lhs, ` < *`, rhs, `{`)
	p.In()
	p.P(`return -1`)
	p.Out()
	p.P(`}`)
	p.P(`return 1`)
	p.Out()
	p.P(`}`)
	p.Out()

	// Exactly one set: the side holding a value sorts after the nil side.
	p.P(`} else if `, lhs, ` != nil {`)
	p.In()
	p.P(`return 1`)
	p.Out()
	p.P(`} else if `, rhs, ` != nil {`)
	p.In()
	p.P(`return -1`)
	p.Out()
	p.P(`}`)
}

// generateMsgNullAndTypeCheck emits the prologue of a generated Compare
// method for the type named ccTypeName.
//
// The emitted code handles a nil `that`, converts `that` into `that1` (a
// *ccTypeName, accepting either a pointer or a value), returns 1 when `that`
// has any other type, and finally orders nil receivers and nil arguments:
// both nil yields 0, only that1 nil yields 1, only this nil yields -1.
func (p *plugin) generateMsgNullAndTypeCheck(ccTypeName string) {
	// Body shared by both "the other side is nil" branches: equal when this is
	// nil as well, otherwise this sorts after it.
	otherIsNilBody := func() {
		p.In()
		p.P(`if this == nil {`)
		p.In()
		p.P(`return 0`)
		p.Out()
		p.P(`}`)
		p.P(`return 1`)
		p.Out()
	}

	p.P(`if that == nil {`)
	otherIsNilBody()
	p.P(`}`)
	p.P(``)

	// Accept *T first, then fall back to T by value.
	p.P(`that1, ok := that.(*`, ccTypeName, `)`)
	p.P(`if !ok {`)
	p.In()
	p.P(`that2, ok := that.(`, ccTypeName, `)`)
	p.P(`if ok {`)
	p.In()
	p.P(`that1 = &that2`)
	p.Out()
	p.P(`} else {`)
	p.In()
	p.P(`return 1`)
	p.Out()
	p.P(`}`)
	p.Out()
	p.P(`}`)

	// that was a non-nil interface but may still hold a nil pointer.
	p.P(`if that1 == nil {`)
	otherIsNilBody()
	p.P(`} else if this == nil {`)
	p.In()
	p.P(`return -1`)
	p.Out()
	p.P(`}`)
}

// generateReturnC emits the body and closing brace of an `if` statement whose
// header (already emitted by the caller) bound the comparison result to `c`:
//
//	    return c
//	}
func (p *plugin) generateReturnC() {
	p.In()
	p.P(`return c`)
	p.Out()
	p.P(`}`)
}

// generateOrdering emits a two-level ordering check. outerIf and innerIf are
// the complete opening lines (including the trailing brace) of the "values
// differ" test and of the "this is smaller" test respectively:
//
//	<outerIf>
//	    <innerIf>
//	        return -1
//	    }
//	    return 1
//	}
func (p *plugin) generateOrdering(outerIf, innerIf string) {
	p.P(outerIf)
	p.In()
	p.P(innerIf)
	p.In()
	p.P(`return -1`)
	p.Out()
	p.P(`}`)
	p.P(`return 1`)
	p.Out()
	p.P(`}`)
}

// generateMapElementCompare emits the comparison of one map element inside
// the generated `for i := range this.<field>` loop. elemThis and elemThat are
// the Go expressions for the element on either side (e.g. `this.F[i]`).
func (p *plugin) generateMapElementCompare(field *descriptor.FieldDescriptorProto, elemThis, elemThat string) {
	m := p.GoMapType(nil, field)
	valuegoTyp, _ := p.GoType(nil, m.ValueField)
	valuegoAliasTyp, _ := p.GoType(nil, m.ValueAliasField)
	// For map values the nullability comes from GoMapValueTypes, not from the
	// field-level nullable option.
	nullable, valuegoTyp, valuegoAliasTyp := generator.GoMapValueTypes(field, m.ValueField, valuegoTyp, valuegoAliasTyp)

	mapValue := m.ValueAliasField
	switch {
	case mapValue.IsMessage() || p.IsGroup(mapValue):
		if nullable && valuegoTyp == valuegoAliasTyp {
			p.P(`if c := `, elemThis, `.Compare(`, elemThat, `); c != 0 {`)
		} else {
			// Compare() has a pointer receiver, but map value is a value type
			a, b := elemThis, elemThat
			if valuegoTyp != valuegoAliasTyp {
				// cast back to the type that has the generated methods on it
				a = `(` + valuegoTyp + `)(` + a + `)`
				b = `(` + valuegoTyp + `)(` + b + `)`
			}
			p.P(`a := `, a)
			p.P(`b := `, b)
			if nullable {
				p.P(`if c := a.Compare(b); c != 0 {`)
			} else {
				p.P(`if c := (&a).Compare(&b); c != 0 {`)
			}
		}
		p.generateReturnC()
	case mapValue.IsBytes():
		p.P(`if c := `, p.bytesPkg.Use(), `.Compare(`, elemThis, `, `, elemThat, `); c != 0 {`)
		p.generateReturnC()
	case mapValue.IsBool():
		// Booleans are not ordered in Go, so `<` cannot be emitted for them;
		// false sorts before true, as for repeated bool fields.
		p.generateOrdering(`if `+elemThis+` != `+elemThat+` {`, `if !`+elemThis+` {`)
	default:
		// Strings and numeric values support != and < directly.
		p.generateOrdering(`if `+elemThis+` != `+elemThat+` {`, `if `+elemThis+` < `+elemThat+` {`)
	}
}

// generateField emits, into the body of a generated Compare method, the
// statements that compare one field of `this` with the same field of `that1`.
// The emitted code returns a non-zero result as soon as the field differs and
// falls through otherwise.
//
// Non-repeated fields:
//   - custom types and message/group fields delegate to the field's Compare
//     method (nullable custom types get explicit nil ordering first);
//   - bytes use bytes.Compare;
//   - bool orders false before true;
//   - everything else (strings, numbers) is ordered with != and <.
//
// Scalars are dereferenced, with nil ordered before non-nil, when the field is
// nullable and the file is not proto3.
//
// Repeated and map fields: a shorter collection orders first; collections of
// equal length are compared element by element over `range this.<field>`
// using the same per-kind rules.
func (p *plugin) generateField(file *generator.FileDescriptor, message *generator.Descriptor, field *descriptor.FieldDescriptorProto) {
	proto3 := gogoproto.IsProto3(file.FileDescriptorProto)
	fieldname := p.GetOneOfFieldName(message, field)
	ctype := gogoproto.IsCustomType(field)
	nullable := gogoproto.IsNullable(field)
	lhs, rhs := `this.`+fieldname, `that1.`+fieldname

	if !field.IsRepeated() {
		// The emitted scalar comparisons dereference the field only in this case.
		pointer := nullable && !proto3

		switch {
		case ctype:
			if nullable {
				p.P(`if `, rhs, ` == nil {`)
				p.In()
				p.P(`if `, lhs, ` != nil {`)
				p.In()
				p.P(`return 1`)
				p.Out()
				p.P(`}`)
				p.Out()
				p.P(`} else if `, lhs, ` == nil {`)
				p.In()
				p.P(`return -1`)
				p.Out()
				p.P(`} else if c := `, lhs, `.Compare(*`, rhs, `); c != 0 {`)
			} else {
				p.P(`if c := `, lhs, `.Compare(`, rhs, `); c != 0 {`)
			}
			p.generateReturnC()
		case field.IsMessage() || p.IsGroup(field):
			if nullable {
				p.P(`if c := `, lhs, `.Compare(`, rhs, `); c != 0 {`)
			} else {
				// Non-nullable fields are emitted as values, so the address is taken.
				p.P(`if c := `, lhs, `.Compare(&`, rhs, `); c != 0 {`)
			}
			p.generateReturnC()
		case field.IsBytes():
			p.P(`if c := `, p.bytesPkg.Use(), `.Compare(`, lhs, `, `, rhs, `); c != 0 {`)
			p.generateReturnC()
		case field.IsBool():
			if pointer {
				p.P(`if `, lhs, ` != nil && `, rhs, ` != nil {`)
				p.In()
				p.generateOrdering(`if *`+lhs+` != *`+rhs+`{`, `if !*`+lhs+` {`)
				p.Out()
				p.P(`} else if `, lhs, ` != nil {`)
				p.In()
				p.P(`return 1`)
				p.Out()
				p.P(`} else if `, rhs, ` != nil {`)
				p.In()
				p.P(`return -1`)
				p.Out()
				p.P(`}`)
			} else {
				p.generateOrdering(`if `+lhs+` != `+rhs+`{`, `if !`+lhs+` {`)
			}
		default:
			// Strings and numeric types.
			if pointer {
				p.generateNullableField(fieldname)
			} else {
				p.generateOrdering(`if `+lhs+` != `+rhs+`{`, `if `+lhs+` < `+rhs+`{`)
			}
		}
		return
	}

	// Repeated or map field: order by length first, then element-wise.
	p.generateOrdering(`if len(`+lhs+`) != len(`+rhs+`) {`, `if len(`+lhs+`) < len(`+rhs+`) {`)
	elemThis, elemThat := lhs+`[i]`, rhs+`[i]`
	p.P(`for i := range `, lhs, ` {`)
	p.In()
	switch {
	case ctype:
		p.P(`if c := `, elemThis, `.Compare(`, elemThat, `); c != 0 {`)
		p.generateReturnC()
	case p.IsMap(field):
		p.generateMapElementCompare(field, elemThis, elemThat)
	case field.IsMessage() || p.IsGroup(field):
		if nullable {
			p.P(`if c := `, elemThis, `.Compare(`, elemThat, `); c != 0 {`)
		} else {
			p.P(`if c := `, elemThis, `.Compare(&`, elemThat, `); c != 0 {`)
		}
		p.generateReturnC()
	case field.IsBytes():
		p.P(`if c := `, p.bytesPkg.Use(), `.Compare(`, elemThis, `, `, elemThat, `); c != 0 {`)
		p.generateReturnC()
	case field.IsBool():
		p.generateOrdering(`if `+elemThis+` != `+elemThat+` {`, `if !`+elemThis+` {`)
	default:
		// Strings and numeric types.
		p.generateOrdering(`if `+elemThis+` != `+elemThat+` {`, `if `+elemThis+` < `+elemThat+` {`)
	}
	p.Out()
	p.P(`}`)
}

// generateOneofTypeSwitch emits a type switch over operand (the oneof
// interface value on one side of the comparison) that stores, in the
// generated variable typeVar, the index within message.Field of the member
// whose wrapper type operand holds. Only members of the same oneof group as
// field get a case; any other dynamic type makes the generated code panic.
func (p *plugin) generateOneofTypeSwitch(message *generator.Descriptor, field *descriptor.FieldDescriptorProto, typeVar, operand string) {
	p.P(typeVar, ` := -1`)
	p.P(`switch `, operand, `.(type) {`)
	for i, subfield := range message.Field {
		// Ordinary fields have no OneofIndex; dereferencing it would panic.
		if subfield.OneofIndex == nil || *subfield.OneofIndex != *field.OneofIndex {
			continue
		}
		p.P(`case *`, p.OneOfTypeName(message, subfield), `:`)
		p.In()
		p.P(typeVar, ` = `, i)
		p.Out()
	}
	p.P(`default:`)
	p.In()
	// Reference fmt through the import handle so the import is recorded as used.
	p.P(`panic(`, p.fmtPkg.Use(), `.Sprintf("compare: unexpected type %T in oneof", `, operand, `))`)
	p.Out()
	p.P(`}`)
}

// generateOneofCompare emits the comparison of the oneof group that field
// belongs to; fieldname is the Go struct field holding the group's value.
//
// In the generated code a nil value orders before a set one. When both sides
// are set, two type switches determine which member each side holds: equal
// members are compared with the wrapper's Compare method, otherwise the side
// holding the member declared earlier in message.Field orders first.
func (p *plugin) generateOneofCompare(message *generator.Descriptor, field *descriptor.FieldDescriptorProto, fieldname string) {
	lhs, rhs := `this.`+fieldname, `that1.`+fieldname

	p.P(`if `, rhs, ` == nil {`)
	p.In()
	p.P(`if `, lhs, ` != nil {`)
	p.In()
	p.P(`return 1`)
	p.Out()
	p.P(`}`)
	p.Out()
	p.P(`} else if `, lhs, ` == nil {`)
	p.In()
	p.P(`return -1`)
	p.Out()
	p.P(`} else {`)
	p.In()

	p.generateOneofTypeSwitch(message, field, `thisType`, lhs)
	p.generateOneofTypeSwitch(message, field, `that1Type`, rhs)

	p.P(`if thisType == that1Type {`)
	p.In()
	p.P(`if c := `, lhs, `.Compare(`, rhs, `); c != 0 {`)
	p.In()
	p.P(`return c`)
	p.Out()
	p.P(`}`)
	p.Out()
	p.P(`} else if thisType < that1Type {`)
	p.In()
	p.P(`return -1`)
	p.Out()
	p.P(`} else if thisType > that1Type {`)
	p.In()
	p.P(`return 1`)
	p.Out()
	p.P(`}`)
	p.Out()
	p.P(`}`)
}

// generateExtensionsCompare emits the comparison of the extensions of message.
//
// When the message uses an extensions map, the generated code walks the union
// of both sides' extension keys in ascending order: an extension present on
// both sides is compared with Compare, one present only on this yields 1, and
// one present only on that1 yields -1. Otherwise the XXX_extensions byte
// slices are compared with bytes.Compare.
func (p *plugin) generateExtensionsCompare(file *generator.FileDescriptor, message *generator.Descriptor) {
	if !gogoproto.HasExtensionsMap(file.FileDescriptorProto, message.DescriptorProto) {
		fieldname := "XXX_extensions"
		p.P(`if c := `, p.bytesPkg.Use(), `.Compare(this.`, fieldname, `, that1.`, fieldname, `); c != 0 {`)
		p.In()
		p.P(`return c`)
		p.Out()
		p.P(`}`)
		return
	}

	p.P(`thismap := `, p.protoPkg.Use(), `.GetUnsafeExtensionsMap(this)`)
	p.P(`thatmap := `, p.protoPkg.Use(), `.GetUnsafeExtensionsMap(that1)`)
	// Collect every key exactly once: all of this, plus those only in that1.
	p.P(`extkeys := make([]int32, 0, len(thismap)+len(thatmap))`)
	p.P(`for k, _ := range thismap {`)
	p.In()
	p.P(`extkeys = append(extkeys, k)`)
	p.Out()
	p.P(`}`)
	p.P(`for k, _ := range thatmap {`)
	p.In()
	p.P(`if _, ok := thismap[k]; !ok {`)
	p.In()
	p.P(`extkeys = append(extkeys, k)`)
	p.Out()
	p.P(`}`)
	p.Out()
	p.P(`}`)
	// Sorting makes the result independent of map iteration order.
	p.P(p.sortkeysPkg.Use(), `.Int32s(extkeys)`)
	p.P(`for _, k := range extkeys {`)
	p.In()
	p.P(`if v, ok := thismap[k]; ok {`)
	p.In()
	p.P(`if v2, ok := thatmap[k]; ok {`)
	p.In()
	p.P(`if c := v.Compare(&v2); c != 0 {`)
	p.In()
	p.P(`return c`)
	p.Out()
	p.P(`}`)
	p.Out()
	p.P(`} else  {`)
	p.In()
	p.P(`return 1`)
	p.Out()
	p.P(`}`)
	p.Out()
	p.P(`} else {`)
	p.In()
	p.P(`return -1`)
	p.Out()
	p.P(`}`)
	p.Out()
	p.P(`}`)
}

// generateMessage emits the Compare method of message, followed by one
// Compare method for each oneof wrapper type of that message.
//
// The message-level method compares, in this order: the fields in declaration
// order (each oneof group once, at the position of its first member), the
// extensions if the message declares any, and XXX_unrecognized if the message
// keeps unrecognized bytes. The generated method returns at the first
// difference and 0 when none is found.
func (p *plugin) generateMessage(file *generator.FileDescriptor, message *generator.Descriptor) {
	ccTypeName := generator.CamelCaseSlice(message.TypeName())
	p.P(`func (this *`, ccTypeName, `) Compare(that interface{}) int {`)
	p.In()
	p.generateMsgNullAndTypeCheck(ccTypeName)

	// Oneof groups already emitted, keyed by the name GetFieldName returns for
	// their members (presumably the group's struct field name, shared by all
	// members of the group).
	seenOneofs := make(map[string]struct{})
	for _, field := range message.Field {
		if field.OneofIndex == nil {
			p.generateField(file, message, field)
			continue
		}
		fieldname := p.GetFieldName(message, field)
		if _, done := seenOneofs[fieldname]; done {
			continue
		}
		seenOneofs[fieldname] = struct{}{}
		p.generateOneofCompare(message, field, fieldname)
	}

	if message.DescriptorProto.HasExtension() {
		p.generateExtensionsCompare(file, message)
	}
	if gogoproto.HasUnrecognized(file.FileDescriptorProto, message.DescriptorProto) {
		fieldname := "XXX_unrecognized"
		p.P(`if c := `, p.bytesPkg.Use(), `.Compare(this.`, fieldname, `, that1.`, fieldname, `); c != 0 {`)
		p.In()
		p.P(`return c`)
		p.Out()
		p.P(`}`)
	}
	p.P(`return 0`)
	p.Out()
	p.P(`}`)

	// Generate Compare methods for the oneof wrapper types. The descriptor is
	// cloned because TurnOffNullableForNativeTypes modifies the field it is
	// given, and the original message descriptor must stay untouched.
	m := proto.Clone(message.DescriptorProto).(*descriptor.DescriptorProto)
	for _, field := range m.Field {
		if field.OneofIndex == nil {
			continue
		}
		ccTypeName := p.OneOfTypeName(message, field)
		p.P(`func (this *`, ccTypeName, `) Compare(that interface{}) int {`)
		p.In()

		p.generateMsgNullAndTypeCheck(ccTypeName)
		vanity.TurnOffNullableForNativeTypes(field)
		p.generateField(file, message, field)

		p.P(`return 0`)
		p.Out()
		p.P(`}`)
	}
}

func init() {
	generator.RegisterPlugin(NewPlugin())
}