330 lines
8.1 KiB
Go
330 lines
8.1 KiB
Go
package main
|
|
|
|
import (
|
|
"bytes"
|
|
"flag"
|
|
"fmt"
|
|
"go/ast"
|
|
"go/token"
|
|
"go/types"
|
|
"log"
|
|
"os"
|
|
"strconv"
|
|
"strings"
|
|
|
|
"golang.org/x/tools/go/packages"
|
|
"golang.org/x/tools/imports"
|
|
"tailscale.com/util/codegen"
|
|
)
|
|
|
|
// Command-line flags.
var (
	// flagTypes holds the comma-separated names of the types to generate
	// MarshalJSONInto methods for.
	flagTypes = flag.String("type", "", "comma-separated list of types; required")

	// flagBuildTags is forwarded to the package loader as -tags=<value>.
	flagBuildTags = flag.String("tags", "", "compiler build tags to apply")
)
|
|
|
|
func main() {
|
|
log.SetFlags(0)
|
|
log.SetPrefix("cloner: ")
|
|
log.SetOutput(os.Stderr)
|
|
flag.Parse()
|
|
if len(*flagTypes) == 0 {
|
|
flag.Usage()
|
|
os.Exit(2)
|
|
}
|
|
typeNames := strings.Split(*flagTypes, ",")
|
|
|
|
pkg, namedTypes, err := loadTypes(".", *flagBuildTags)
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
it := codegen.NewImportTracker(pkg.Types)
|
|
buf := new(bytes.Buffer)
|
|
|
|
for _, typeName := range typeNames {
|
|
typ, ok := namedTypes[typeName]
|
|
if !ok {
|
|
log.Fatalf("could not find type %s", typeName)
|
|
}
|
|
gen(buf, it, typ)
|
|
}
|
|
|
|
outBuf := new(bytes.Buffer)
|
|
outBuf.WriteString("// Code generated by TODO; DO NOT EDIT.\n")
|
|
outBuf.WriteString("\n")
|
|
fmt.Fprintf(outBuf, "package %s\n\n", pkg.Name)
|
|
it.Write(outBuf)
|
|
outBuf.Write(buf.Bytes())
|
|
|
|
// Best-effort gofmt the output
|
|
out := outBuf.Bytes()
|
|
out, err = imports.Process("/nonexistant/main.go", out, &imports.Options{
|
|
Comments: true,
|
|
TabIndent: true,
|
|
TabWidth: 8,
|
|
FormatOnly: true, // fancy gofmt only
|
|
})
|
|
if err != nil {
|
|
out = outBuf.Bytes()
|
|
}
|
|
fmt.Print(string(out))
|
|
}
|
|
|
|
func gen(buf *bytes.Buffer, it *codegen.ImportTracker, typ *types.Named) {
|
|
t, ok := typ.Underlying().(*types.Struct)
|
|
if !ok {
|
|
return
|
|
}
|
|
|
|
name := typ.Obj().Name()
|
|
fmt.Fprintf(buf, "// MarshalJSONInto marshals this %s into JSON in the provided buffer.\n", name)
|
|
fmt.Fprintf(buf, "func (self *%s) MarshalJSONInto(buf []byte) ([]byte, error) {\n", name)
|
|
fmt.Fprintf(buf, "\tvar err error\n")
|
|
fmt.Fprintf(buf, "\t_ = err\n")
|
|
|
|
g := &generator{
|
|
buf: buf,
|
|
it: it,
|
|
indentLevel: 1,
|
|
}
|
|
|
|
g.writef(`buf = append(buf, '{')`)
|
|
for i := 0; i < t.NumFields(); i++ {
|
|
fname := t.Field(i).Name()
|
|
ft := t.Field(i).Type()
|
|
|
|
g.writef("")
|
|
g.writef(`// Encode field %s of type %q`, fname, ft.String())
|
|
|
|
// Write the field name; we need to quote the field (for JSON)
|
|
// and then quote it again (for the generated Go code).
|
|
qfname := strconv.Quote(fname) + ":"
|
|
g.writef(`buf = append(buf, []byte(%q)...)`, qfname)
|
|
|
|
// Write the value
|
|
g.encode("self."+fname, ft)
|
|
|
|
if i < t.NumFields()-1 {
|
|
g.writef(`buf = append(buf, ',')`)
|
|
}
|
|
}
|
|
g.writef(`buf = append(buf, '}')`)
|
|
|
|
g.writef("return buf, nil")
|
|
fmt.Fprintf(buf, "}\n\n")
|
|
}
|
|
|
|
type generator struct {
|
|
buf *bytes.Buffer
|
|
it *codegen.ImportTracker
|
|
indentLevel int
|
|
}
|
|
|
|
func (g *generator) writef(format string, args ...any) {
|
|
fmt.Fprintf(g.buf, strings.Repeat("\t", g.indentLevel)+format+"\n", args...)
|
|
}
|
|
|
|
func (g *generator) indent() {
|
|
g.indentLevel++
|
|
}
|
|
|
|
func (g *generator) dedent() {
|
|
g.indentLevel--
|
|
}
|
|
|
|
func (g *generator) encode(accessor string, ft types.Type) {
|
|
switch ft := ft.Underlying().(type) {
|
|
case *types.Basic:
|
|
g.encodeBasicField(accessor, ft)
|
|
case *types.Slice:
|
|
g.encodeSlice(accessor, ft)
|
|
case *types.Map:
|
|
g.encodeMap(accessor, ft)
|
|
case *types.Struct:
|
|
g.encodeStruct(accessor)
|
|
case *types.Pointer:
|
|
g.encodePointer(accessor, ft)
|
|
default:
|
|
g.writef(`panic("TODO: %s (%T)")`, accessor, ft)
|
|
}
|
|
}
|
|
|
|
func (g *generator) encodePointer(accessor string, ft *types.Pointer) {
|
|
g.writef("if %s != nil {", accessor)
|
|
g.indent()
|
|
// Don't deref for a struct, since we're going to call a function
|
|
// anyway; otherwise, do.
|
|
if _, ok := ft.Elem().Underlying().(*types.Struct); ok {
|
|
g.encode(accessor, ft.Elem())
|
|
} else {
|
|
g.encode("(*"+accessor+")", ft.Elem())
|
|
}
|
|
g.dedent()
|
|
g.writef("} else {")
|
|
g.writef("\tbuf = append(buf, []byte(\"null\")...)")
|
|
g.writef("}")
|
|
}
|
|
|
|
func (g *generator) encodeMap(accessor string, ft *types.Map) {
|
|
kt := ft.Key().Underlying()
|
|
vt := ft.Elem().Underlying()
|
|
|
|
g.writef(`buf = append(buf, '{')`)
|
|
|
|
// Determine how we marshal our key type
|
|
marshalKey := func() {
|
|
g.encode("k", kt)
|
|
}
|
|
|
|
// Now check how we marshal our value
|
|
switch vt := vt.(type) {
|
|
case *types.Basic:
|
|
g.writef("for k, v := range %s {", accessor)
|
|
marshalKey()
|
|
g.writef("\tbuf = append(buf, ':')")
|
|
g.encodeBasicField("v", vt)
|
|
g.writef("}")
|
|
case *types.Struct:
|
|
g.writef("for k, v := range %s {", accessor)
|
|
marshalKey()
|
|
g.writef("\tbuf = append(buf, ':')")
|
|
g.encodeStruct("v")
|
|
g.writef("}")
|
|
default:
|
|
g.writef(`panic("TODO: %s (%T)")`, accessor, vt)
|
|
}
|
|
|
|
g.writef(`buf = append(buf, '}')`)
|
|
}
|
|
|
|
func (g *generator) encodeStruct(accessor string) {
|
|
// Assume that this struct also has a MarshalJSONInto method.
|
|
g.writef("buf, err = %s.MarshalJSONInto(buf)", accessor)
|
|
g.writef("if err != nil {")
|
|
g.writef("\treturn nil, err")
|
|
g.writef("}")
|
|
}
|
|
|
|
func (g *generator) encodeSlice(accessor string, sl *types.Slice) {
|
|
switch ft := sl.Elem().Underlying().(type) {
|
|
case *types.Basic:
|
|
// Slice of basic elements
|
|
switch ft.Kind() {
|
|
case types.Byte:
|
|
// base64-encode
|
|
g.it.Import("encoding/base64")
|
|
|
|
g.writef(`buf = append(buf, '"')`)
|
|
g.writef("{")
|
|
|
|
// buf = append(buf, make([]byte, N)...) is a fast way to grow the slice by N
|
|
g.writef("encodedLen := base64.StdEncoding.EncodedLen(len(%s))", accessor)
|
|
g.writef("offset := len(buf)")
|
|
g.writef("buf = append(buf, make([]byte, encodedLen)...)")
|
|
g.writef("base64.StdEncoding.Encode(buf[offset:], %s)", accessor)
|
|
|
|
g.writef("}")
|
|
g.writef(`buf = append(buf, '"')`)
|
|
default:
|
|
// All other basic elements are encoded
|
|
// one at a time via encodeBasicField
|
|
g.writef(`buf = append(buf, '[')`)
|
|
g.writef(`for i, elem := range %s {`, accessor)
|
|
g.writef("\tif i > 0 {")
|
|
g.writef("\t\tbuf = append(buf, ',')")
|
|
g.writef("\t}")
|
|
g.encodeBasicField("elem", ft)
|
|
g.writef(`}`)
|
|
g.writef(`buf = append(buf, ']')`)
|
|
}
|
|
|
|
case *types.Struct:
|
|
g.writef(`buf = append(buf, '[')`)
|
|
g.writef(`for i, elem := range %s {`, accessor)
|
|
g.writef("\tif i > 0 {")
|
|
g.writef("\t\tbuf = append(buf, ',')")
|
|
g.writef("\t}")
|
|
g.encodeStruct("elem")
|
|
g.writef(`}`)
|
|
g.writef(`buf = append(buf, ']')`)
|
|
|
|
default:
|
|
// TODO: if the type implements our interface,
|
|
// call that function for everything in the
|
|
// slice.
|
|
g.writef(`panic("TODO: %s (%T)")`, accessor, ft)
|
|
}
|
|
}
|
|
|
|
func (g *generator) encodeBasicField(accessor string, field *types.Basic) {
|
|
switch field.Kind() {
|
|
case types.Bool:
|
|
g.writef("if %s {", accessor)
|
|
g.writef(`buf = append(buf, []byte("true")...)`)
|
|
g.writef("} else {")
|
|
g.writef(`buf = append(buf, []byte("false")...)`)
|
|
g.writef("}")
|
|
case types.Int, types.Int8, types.Int16, types.Int32, types.Int64:
|
|
g.it.Import("strconv")
|
|
g.writef("buf = strconv.AppendInt(buf, int64(%s), 10)", accessor)
|
|
case types.Uint, types.Uint8, types.Uint16, types.Uint32, types.Uint64:
|
|
g.it.Import("strconv")
|
|
g.writef("buf = strconv.AppendUint(buf, uint64(%s), 10)", accessor)
|
|
case types.String:
|
|
g.it.Import("strconv")
|
|
g.writef("buf = strconv.AppendQuote(buf, %s)", accessor)
|
|
default:
|
|
g.writef(`panic("TODO: %s (%T)")`, accessor, field.Kind)
|
|
}
|
|
}
|
|
|
|
func loadTypes(pkgName, buildTags string) (*packages.Package, map[string]*types.Named, error) {
|
|
cfg := &packages.Config{
|
|
Mode: packages.NeedTypes |
|
|
packages.NeedTypesInfo |
|
|
packages.NeedSyntax |
|
|
packages.NeedName,
|
|
Tests: false,
|
|
}
|
|
if buildTags != "" {
|
|
cfg.BuildFlags = []string{"-tags=" + buildTags}
|
|
}
|
|
|
|
pkgs, err := packages.Load(cfg, pkgName)
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
if len(pkgs) != 1 {
|
|
return nil, nil, fmt.Errorf("wrong number of packages: %d", len(pkgs))
|
|
}
|
|
pkg := pkgs[0]
|
|
return pkg, namedTypes(pkg), nil
|
|
}
|
|
|
|
func namedTypes(pkg *packages.Package) map[string]*types.Named {
|
|
nt := make(map[string]*types.Named)
|
|
for _, file := range pkg.Syntax {
|
|
for _, d := range file.Decls {
|
|
decl, ok := d.(*ast.GenDecl)
|
|
if !ok || decl.Tok != token.TYPE {
|
|
continue
|
|
}
|
|
for _, s := range decl.Specs {
|
|
spec, ok := s.(*ast.TypeSpec)
|
|
if !ok {
|
|
continue
|
|
}
|
|
typeNameObj, ok := pkg.TypesInfo.Defs[spec.Name]
|
|
if !ok {
|
|
continue
|
|
}
|
|
typ, ok := typeNameObj.Type().(*types.Named)
|
|
if !ok {
|
|
continue
|
|
}
|
|
nt[spec.Name.Name] = typ
|
|
}
|
|
}
|
|
}
|
|
return nt
|
|
}
|