tailscale/cmd/fastjson/fastjson.go

330 lines
8.1 KiB
Go

package main
import (
"bytes"
"flag"
"fmt"
"go/ast"
"go/token"
"go/types"
"log"
"os"
"strconv"
"strings"
"golang.org/x/tools/go/packages"
"golang.org/x/tools/imports"
"tailscale.com/util/codegen"
)
// Command-line flags.
var (
// flagTypes is the -type flag: a comma-separated list of type names to
// generate marshalers for. main prints usage and exits when it is empty.
flagTypes = flag.String("type", "", "comma-separated list of types; required")
// flagBuildTags is the -tags flag. When non-empty it is passed to the
// package loader as a "-tags=" build flag.
flagBuildTags = flag.String("tags", "", "compiler build tags to apply")
)
// main is the entry point of the fastjson generator. It loads the package in
// the current directory, generates a MarshalJSONInto method for every type
// named by the -type flag, and writes the resulting Go source to standard
// output. Diagnostics go to standard error.
func main() {
	log.SetFlags(0)
	log.SetPrefix("fastjson: ")
	log.SetOutput(os.Stderr)
	flag.Parse()

	// Tolerate surrounding whitespace and stray commas in -type so that
	// "A, B" and "A,B," both name exactly the types A and B.
	var typeNames []string
	for _, name := range strings.Split(*flagTypes, ",") {
		if name = strings.TrimSpace(name); name != "" {
			typeNames = append(typeNames, name)
		}
	}
	if len(typeNames) == 0 {
		flag.Usage()
		os.Exit(2)
	}

	pkg, namedTypes, err := loadTypes(".", *flagBuildTags)
	if err != nil {
		log.Fatal(err)
	}

	// The method bodies are generated first so that the import tracker knows
	// every package they need before the import block is written.
	it := codegen.NewImportTracker(pkg.Types)
	buf := new(bytes.Buffer)
	for _, typeName := range typeNames {
		typ, ok := namedTypes[typeName]
		if !ok {
			log.Fatalf("could not find type %s", typeName)
		}
		gen(buf, it, typ)
	}

	outBuf := new(bytes.Buffer)
	outBuf.WriteString("// Code generated by TODO; DO NOT EDIT.\n")
	outBuf.WriteString("\n")
	fmt.Fprintf(outBuf, "package %s\n\n", pkg.Name)
	it.Write(outBuf)
	outBuf.Write(buf.Bytes())

	// Best-effort gofmt of the output: if formatting fails, the unformatted
	// source is still emitted, but the failure is reported on stderr since it
	// usually means the generated code does not parse.
	out, err := imports.Process("/nonexistant/main.go", outBuf.Bytes(), &imports.Options{
		Comments:   true,
		TabIndent:  true,
		TabWidth:   8,
		FormatOnly: true, // fancy gofmt only
	})
	if err != nil {
		log.Printf("warning: could not format generated code, emitting it unformatted: %v", err)
		out = outBuf.Bytes()
	}

	if _, err := os.Stdout.Write(out); err != nil {
		log.Fatal(err)
	}
}
// gen writes to buf the source of a MarshalJSONInto method for typ.
//
// The generated method has a pointer receiver named self, appends a JSON
// object to the supplied buffer and returns it. Fields are encoded in
// declaration order, keyed by their Go field name; struct tags are not
// consulted. Blank ("_") fields are skipped because they cannot be referenced
// from generated code. Imports needed by the generated body are recorded in
// it.
//
// If typ's underlying type is not a struct, gen writes nothing.
func gen(buf *bytes.Buffer, it *codegen.ImportTracker, typ *types.Named) {
	t, ok := typ.Underlying().(*types.Struct)
	if !ok {
		return
	}

	// Collect the encodable fields up front so that the separating commas are
	// placed relative to the fields actually emitted.
	fields := make([]*types.Var, 0, t.NumFields())
	for i := 0; i < t.NumFields(); i++ {
		if f := t.Field(i); f.Name() != "_" {
			fields = append(fields, f)
		}
	}

	name := typ.Obj().Name()
	fmt.Fprintf(buf, "// MarshalJSONInto marshals this %s into JSON in the provided buffer.\n", name)
	fmt.Fprintf(buf, "func (self *%s) MarshalJSONInto(buf []byte) ([]byte, error) {\n", name)
	buf.WriteString("\tvar err error\n")
	// Not every field encoder assigns err; the blank assignment keeps the
	// generated method compiling when none does.
	buf.WriteString("\t_ = err\n")

	g := &generator{
		buf:         buf,
		it:          it,
		indentLevel: 1,
	}
	g.writef(`buf = append(buf, '{')`)
	for i, f := range fields {
		fname, ft := f.Name(), f.Type()

		g.writef("")
		g.writef(`// Encode field %s of type %q`, fname, ft.String())

		// Write the field name; we need to quote the field (for JSON)
		// and then quote it again (for the generated Go code).
		qfname := strconv.Quote(fname) + ":"
		g.writef(`buf = append(buf, []byte(%q)...)`, qfname)

		// Write the value
		g.encode("self."+fname, ft)

		if i < len(fields)-1 {
			g.writef(`buf = append(buf, ',')`)
		}
	}
	g.writef(`buf = append(buf, '}')`)
	g.writef("return buf, nil")
	buf.WriteString("}\n\n")
}
// generator holds the state used while emitting the body of one generated
// method.
type generator struct {
// buf receives the generated source; writef appends to it.
buf *bytes.Buffer
// it records the imports that the generated code needs.
it *codegen.ImportTracker
// indentLevel is the number of tabs writef puts in front of each line.
indentLevel int
}
// writef appends one line of generated code to g.buf: the current
// indentation, then format expanded with args, then a newline.
func (g *generator) writef(format string, args ...any) {
	prefix := strings.Repeat("\t", g.indentLevel)
	g.buf.WriteString(prefix)
	fmt.Fprintf(g.buf, format+"\n", args...)
}
// indent makes subsequent writef calls emit one more leading tab.
func (g *generator) indent() {
	g.indentLevel += 1
}
// dedent makes subsequent writef calls emit one fewer leading tab.
func (g *generator) dedent() {
	g.indentLevel -= 1
}
// encode emits code that appends the JSON encoding of the expression
// accessor, whose type is ft, to buf. It dispatches on ft's underlying type;
// for kinds without an encoder it emits a panic placeholder instead.
func (g *generator) encode(accessor string, ft types.Type) {
	switch under := ft.Underlying().(type) {
	case *types.Pointer:
		g.encodePointer(accessor, under)
	case *types.Struct:
		g.encodeStruct(accessor)
	case *types.Map:
		g.encodeMap(accessor, under)
	case *types.Slice:
		g.encodeSlice(accessor, under)
	case *types.Basic:
		g.encodeBasicField(accessor, under)
	default:
		g.writef(`panic("TODO: %s (%T)")`, accessor, under)
	}
}
// encodePointer emits code that appends the JSON encoding of the pointer
// expression accessor: the encoding of the pointee when the pointer is
// non-nil, and the literal null otherwise.
func (g *generator) encodePointer(accessor string, ft *types.Pointer) {
	pointee := ft.Elem()

	// Structs are encoded through a method call, so the pointer itself is
	// used as the operand; every other pointee is dereferenced first.
	operand := "(*" + accessor + ")"
	if _, isStruct := pointee.Underlying().(*types.Struct); isStruct {
		operand = accessor
	}

	g.writef("if %s != nil {", accessor)
	g.indent()
	g.encode(operand, pointee)
	g.dedent()
	g.writef("} else {")
	g.indent()
	g.writef(`buf = append(buf, []byte("null")...)`)
	g.dedent()
	g.writef("}")
}
// encodeMap emits code that appends the JSON object encoding of the map
// expression accessor to buf.
//
// Only maps whose value type is a basic type or a struct are supported; for
// any other value type a panic placeholder is emitted instead of the loop.
// Entries are separated by commas. Keys are encoded with encode; integer keys
// are additionally wrapped in double quotes, because JSON object keys must be
// strings. Other key kinds are passed through unchanged.
//
// The generated loop ranges over the map directly, so entries appear in Go's
// map iteration order.
func (g *generator) encodeMap(accessor string, ft *types.Map) {
	kt := ft.Key().Underlying()
	vt := ft.Elem().Underlying()

	// Choose how the loop variable "v" is encoded; nil means unsupported.
	var encodeValue func()
	switch vt := vt.(type) {
	case *types.Basic:
		encodeValue = func() { g.encodeBasicField("v", vt) }
	case *types.Struct:
		encodeValue = func() { g.encodeStruct("v") }
	}

	// Integer keys would otherwise be written as bare numbers, which is not
	// a valid JSON object key.
	quoteKey := false
	if basic, ok := kt.(*types.Basic); ok && basic.Info()&types.IsInteger != 0 {
		quoteKey = true
	}

	g.writef(`buf = append(buf, '{')`)
	if encodeValue == nil {
		g.writef(`panic("TODO: %s (%T)")`, accessor, vt)
	} else {
		// The generated loop lives in its own block so that "first" cannot
		// clash with the same helper variable emitted for another map field.
		g.writef("{")
		g.indent()
		g.writef("first := true")
		g.writef("for k, v := range %s {", accessor)
		g.indent()

		// Separate entries with commas; a range over a map has no index, so
		// a flag tracks whether an entry has already been written.
		g.writef("if !first {")
		g.writef("\tbuf = append(buf, ',')")
		g.writef("}")
		g.writef("first = false")

		if quoteKey {
			g.writef(`buf = append(buf, '"')`)
		}
		g.encode("k", kt)
		if quoteKey {
			g.writef(`buf = append(buf, '"')`)
		}
		g.writef(`buf = append(buf, ':')`)
		encodeValue()

		g.dedent()
		g.writef("}")
		g.dedent()
		g.writef("}")
	}
	g.writef(`buf = append(buf, '}')`)
}
// encodeStruct emits a call to accessor's MarshalJSONInto method that appends
// to buf, followed by an early return if that call reports an error. The
// struct type is assumed to have such a method.
func (g *generator) encodeStruct(accessor string) {
	g.writef("buf, err = %s.MarshalJSONInto(buf)", accessor)
	g.writef("if err != nil {")
	g.indent()
	g.writef("return nil, err")
	g.dedent()
	g.writef("}")
}
// encodeSlice emits code that appends the JSON encoding of the slice
// expression accessor to buf.
//
// Byte slices are written as a base64 string in standard encoding. Slices of
// any other basic type and slices of structs are written as a JSON array,
// one element at a time. Any other element type gets a panic placeholder.
func (g *generator) encodeSlice(accessor string, sl *types.Slice) {
	// emitElem writes the code that encodes the generated loop variable
	// "elem"; it is set only for element types encoded as a JSON array.
	var emitElem func()

	switch et := sl.Elem().Underlying().(type) {
	case *types.Basic:
		if et.Kind() == types.Byte {
			g.it.Import("encoding/base64")
			g.writef(`buf = append(buf, '"')`)
			g.writef("{")
			// Appending a zeroed slice of the encoded length grows buf in
			// place, so the base64 text is written straight into it.
			g.writef("encodedLen := base64.StdEncoding.EncodedLen(len(%s))", accessor)
			g.writef("offset := len(buf)")
			g.writef("buf = append(buf, make([]byte, encodedLen)...)")
			g.writef("base64.StdEncoding.Encode(buf[offset:], %s)", accessor)
			g.writef("}")
			g.writef(`buf = append(buf, '"')`)
			return
		}
		emitElem = func() { g.encodeBasicField("elem", et) }
	case *types.Struct:
		emitElem = func() { g.encodeStruct("elem") }
	default:
		// TODO: if the element type implements our interface, call that
		// function for everything in the slice.
		g.writef(`panic("TODO: %s (%T)")`, accessor, et)
		return
	}

	// Shared array form: '[', comma-separated elements, ']'.
	g.writef(`buf = append(buf, '[')`)
	g.writef(`for i, elem := range %s {`, accessor)
	g.writef("\tif i > 0 {")
	g.writef("\t\tbuf = append(buf, ',')")
	g.writef("\t}")
	emitElem()
	g.writef(`}`)
	g.writef(`buf = append(buf, ']')`)
}
// encodeBasicField emits code that appends the JSON encoding of the
// basic-typed expression accessor to buf.
//
// Booleans, signed and unsigned integers, and strings are supported. Every
// other basic kind gets a panic placeholder naming the unsupported type.
//
// NOTE(review): strings are quoted with strconv.AppendQuote, which produces
// Go string-literal escapes (for example \x00 or \a); not all of those are
// valid JSON escapes. Confirm this is acceptable for the expected inputs.
func (g *generator) encodeBasicField(accessor string, field *types.Basic) {
	switch field.Kind() {
	case types.Bool:
		g.writef("if %s {", accessor)
		g.indent()
		g.writef(`buf = append(buf, []byte("true")...)`)
		g.dedent()
		g.writef("} else {")
		g.indent()
		g.writef(`buf = append(buf, []byte("false")...)`)
		g.dedent()
		g.writef("}")
	case types.Int, types.Int8, types.Int16, types.Int32, types.Int64:
		g.it.Import("strconv")
		g.writef("buf = strconv.AppendInt(buf, int64(%s), 10)", accessor)
	case types.Uint, types.Uint8, types.Uint16, types.Uint32, types.Uint64:
		g.it.Import("strconv")
		g.writef("buf = strconv.AppendUint(buf, uint64(%s), 10)", accessor)
	case types.String:
		g.it.Import("strconv")
		g.writef("buf = strconv.AppendQuote(buf, %s)", accessor)
	default:
		// Report the basic type's name; formatting the Kind method value
		// with %T would only print the method's function type.
		g.writef(`panic("TODO: %s (%s)")`, accessor, field.Name())
	}
}
// loadTypes loads the single package matching pkgName, with syntax trees and
// type information, and returns it together with its named types indexed by
// name. A non-empty buildTags is passed to the loader as a "-tags=" build
// flag. It is an error for the pattern to match anything other than exactly
// one package.
func loadTypes(pkgName, buildTags string) (*packages.Package, map[string]*types.Named, error) {
	var buildFlags []string
	if buildTags != "" {
		buildFlags = append(buildFlags, "-tags="+buildTags)
	}

	loaded, err := packages.Load(&packages.Config{
		Mode:       packages.NeedName | packages.NeedSyntax | packages.NeedTypes | packages.NeedTypesInfo,
		BuildFlags: buildFlags,
	}, pkgName)
	switch {
	case err != nil:
		return nil, nil, err
	case len(loaded) != 1:
		return nil, nil, fmt.Errorf("wrong number of packages: %d", len(loaded))
	}

	only := loaded[0]
	return only, namedTypes(only), nil
}
// namedTypes returns the named types declared by top-level type declarations
// in pkg's syntax trees, indexed by declared name. Type specs whose
// identifier has no entry in pkg.TypesInfo.Defs, or whose object's type is
// not a *types.Named, are left out.
func namedTypes(pkg *packages.Package) map[string]*types.Named {
	result := make(map[string]*types.Named)
	for _, f := range pkg.Syntax {
		for _, node := range f.Decls {
			gd, isGen := node.(*ast.GenDecl)
			if !isGen || gd.Tok != token.TYPE {
				continue
			}
			for _, rawSpec := range gd.Specs {
				ts, isTypeSpec := rawSpec.(*ast.TypeSpec)
				if !isTypeSpec {
					continue
				}
				if obj, defined := pkg.TypesInfo.Defs[ts.Name]; defined {
					if named, isNamed := obj.Type().(*types.Named); isNamed {
						result[ts.Name.Name] = named
					}
				}
			}
		}
	}
	return result
}