| // Copyright 2017 The Bazel Authors. All rights reserved. |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| |
| // embed generates a .go file from the contents of a list of data files. It is |
| // invoked by go_embed_data as an action. |
| package main |
| |
| import ( |
| "archive/tar" |
| "archive/zip" |
| "bufio" |
| "errors" |
| "flag" |
| "fmt" |
| "io" |
| "log" |
| "os" |
| "path" |
| "path/filepath" |
| "strconv" |
| "strings" |
| "text/template" |
| "unicode/utf8" |
| ) |
| |
// Output templates for the generated file. run and embedMultipleFiles execute
// both of them with the *configuration as data.
var (
	// headerTpl renders the "DO NOT EDIT" banner naming the rule's label,
	// followed by the package clause.
	headerTpl = template.Must(template.New("embed").Parse(`// Generated by go_embed_data for {{.Label}}. DO NOT EDIT.

package {{.Package}}

`))

	// multiFooterTpl renders the -multi map variable named .Var, whose value
	// type is .Type. Each entry maps .Key of a FoundSources element to the
	// numbered variable <Var>_<i> written earlier for that element.
	multiFooterTpl = template.Must(template.New("embed").Parse(`
var {{.Var}} = map[string]{{.Type}}{
{{- range $i, $f := .FoundSources}}
	{{$.Key $f}}: {{$.Var}}_{{$i}},
{{- end}}
}

`))
)
| |
| func main() { |
| log.SetPrefix("embed: ") |
| log.SetFlags(0) // don't print timestamps |
| if err := run(os.Args); err != nil { |
| log.Fatal(err) |
| } |
| } |
| |
// configuration holds the options for one invocation of embed, as parsed by
// newConfiguration. Exported fields and methods are referenced by the output
// templates; unexported fields are used only by this program.
type configuration struct {
	// Label is the label of the rule being executed (-label).
	Label string
	// Package is the Go package name of the generated file (-package).
	Package string
	// Var is the name of the generated variable (-var).
	Var string
	// Multi selects a map of values instead of a single value (-multi).
	Multi bool
	// sources holds the non-flag command-line arguments.
	sources []string
	// FoundSources receives, in order, the name of each input that has been
	// embedded so far.
	FoundSources []string
	// out is the Go file to generate (-out).
	out string
	// workspace is the workspace name (-workspace), used by Key to strip an
	// "external/<workspace>/" prefix.
	workspace string
	// flatten makes Key use only the base name (-flatten).
	flatten bool
	// unpack treats each source as an archive to unpack (-unpack).
	unpack bool
	// strData stores contents as strings rather than []byte (-string).
	strData bool
}

// Type returns the Go type of each embedded value: "string" when -string was
// given, otherwise "[]byte".
func (c *configuration) Type() string {
	if c.strData {
		return "string"
	}
	return "[]byte"
}

// Key returns the quoted Go string literal that serves as filename's map key
// in -multi mode.
//
// With -flatten the key is path.Base(filename). Otherwise a leading
// "external/<workspace>/" is removed and the remainder is passed through
// filepath.FromSlash, so the separator matches the host OS.
func (c *configuration) Key(filename string) string {
	var key string
	if c.flatten {
		key = path.Base(filename)
	} else {
		workspacePrefix := "external/" + c.workspace + "/"
		key = filepath.FromSlash(strings.TrimPrefix(filename, workspacePrefix))
	}
	return strconv.Quote(key)
}
| |
| func run(args []string) error { |
| c, err := newConfiguration(args) |
| if err != nil { |
| return err |
| } |
| |
| f, err := os.Create(c.out) |
| if err != nil { |
| return err |
| } |
| defer f.Close() |
| w := bufio.NewWriter(f) |
| defer w.Flush() |
| |
| if err := headerTpl.Execute(w, c); err != nil { |
| return err |
| } |
| |
| if c.Multi { |
| return embedMultipleFiles(c, w) |
| } |
| return embedSingleFile(c, w) |
| } |
| |
| func newConfiguration(args []string) (*configuration, error) { |
| var c configuration |
| flags := flag.NewFlagSet("embed", flag.ExitOnError) |
| flags.StringVar(&c.Label, "label", "", "Label of the rule being executed (required)") |
| flags.StringVar(&c.Package, "package", "", "Go package name (required)") |
| flags.StringVar(&c.Var, "var", "", "Variable name (required)") |
| flags.BoolVar(&c.Multi, "multi", false, "Whether the variable is a map or a single value") |
| flags.StringVar(&c.out, "out", "", "Go file to generate (required)") |
| flags.StringVar(&c.workspace, "workspace", "", "Name of the workspace (required)") |
| flags.BoolVar(&c.flatten, "flatten", false, "Whether to access files by base name") |
| flags.BoolVar(&c.strData, "string", false, "Whether to store contents as strings") |
| flags.BoolVar(&c.unpack, "unpack", false, "Whether to treat files as archives to unpack.") |
| flags.Parse(args[1:]) |
| if c.Label == "" { |
| return nil, errors.New("error: -label option not provided") |
| } |
| if c.Package == "" { |
| return nil, errors.New("error: -package option not provided") |
| } |
| if c.Var == "" { |
| return nil, errors.New("error: -var option not provided") |
| } |
| if c.out == "" { |
| return nil, errors.New("error: -out option not provided") |
| } |
| if c.workspace == "" { |
| return nil, errors.New("error: -workspace option not provided") |
| } |
| c.sources = flags.Args() |
| if !c.Multi && len(c.sources) != 1 { |
| return nil, fmt.Errorf("error: -multi flag not given, so want exactly one source; got %d", len(c.sources)) |
| } |
| if c.unpack { |
| if !c.Multi { |
| return nil, errors.New("error: -multi flag is required for -unpack mode.") |
| } |
| for _, src := range c.sources { |
| if ext := filepath.Ext(src); ext != ".zip" && ext != ".tar" { |
| return nil, fmt.Errorf("error: -unpack flag expects .zip or .tar extension (got %q)", ext) |
| } |
| } |
| } |
| return &c, nil |
| } |
| |
| func embedSingleFile(c *configuration, w io.Writer) error { |
| dataBegin, dataEnd := "\"", "\"\n" |
| if !c.strData { |
| dataBegin, dataEnd = "[]byte(\"", "\")\n" |
| } |
| |
| if _, err := fmt.Fprintf(w, "var %s = %s", c.Var, dataBegin); err != nil { |
| return err |
| } |
| if err := embedFileContents(w, c.sources[0]); err != nil { |
| return err |
| } |
| _, err := fmt.Fprint(w, dataEnd) |
| return err |
| } |
| |
| func embedMultipleFiles(c *configuration, w io.Writer) error { |
| dataBegin, dataEnd := "\"", "\"\n" |
| if !c.strData { |
| dataBegin, dataEnd = "[]byte(\"", "\")\n" |
| } |
| |
| if _, err := fmt.Fprint(w, "var (\n"); err != nil { |
| return err |
| } |
| if err := findSources(c, func(i int, f io.Reader) error { |
| if _, err := fmt.Fprintf(w, "\t%s_%d = %s", c.Var, i, dataBegin); err != nil { |
| return err |
| } |
| if _, err := io.Copy(&escapeWriter{w}, f); err != nil { |
| return err |
| } |
| if _, err := fmt.Fprint(w, dataEnd); err != nil { |
| return err |
| } |
| return nil |
| }); err != nil { |
| return err |
| } |
| if _, err := fmt.Fprint(w, ")\n"); err != nil { |
| return err |
| } |
| if err := multiFooterTpl.Execute(w, c); err != nil { |
| return err |
| } |
| return nil |
| } |
| |
| func findSources(c *configuration, cb func(i int, f io.Reader) error) error { |
| if c.unpack { |
| for _, filename := range c.sources { |
| ext := filepath.Ext(filename) |
| if ext == ".zip" { |
| if err := findZipSources(c, filename, cb); err != nil { |
| return err |
| } |
| } else if ext == ".tar" { |
| if err := findTarSources(c, filename, cb); err != nil { |
| return err |
| } |
| } else { |
| panic("unknown archive extension: " + ext) |
| } |
| } |
| return nil |
| } |
| for _, filename := range c.sources { |
| f, err := os.Open(filename) |
| if err != nil { |
| return err |
| } |
| err = cb(len(c.FoundSources), bufio.NewReader(f)) |
| f.Close() |
| if err != nil { |
| return err |
| } |
| c.FoundSources = append(c.FoundSources, filename) |
| } |
| return nil |
| } |
| |
| func findZipSources(c *configuration, filename string, cb func(i int, f io.Reader) error) error { |
| r, err := zip.OpenReader(filename) |
| if err != nil { |
| return err |
| } |
| defer r.Close() |
| for _, file := range r.File { |
| f, err := file.Open() |
| if err != nil { |
| return err |
| } |
| err = cb(len(c.FoundSources), f) |
| f.Close() |
| if err != nil { |
| return err |
| } |
| c.FoundSources = append(c.FoundSources, file.Name) |
| } |
| return nil |
| } |
| |
| func findTarSources(c *configuration, filename string, cb func(i int, f io.Reader) error) error { |
| tf, err := os.Open(filename) |
| if err != nil { |
| return err |
| } |
| defer tf.Close() |
| reader := tar.NewReader(bufio.NewReader(tf)) |
| for { |
| h, err := reader.Next() |
| if err == io.EOF { |
| return nil |
| } |
| if err != nil { |
| return err |
| } |
| if h.Typeflag != tar.TypeReg { |
| continue |
| } |
| if err := cb(len(c.FoundSources), &io.LimitedReader{ |
| R: reader, |
| N: h.Size, |
| }); err != nil { |
| return err |
| } |
| c.FoundSources = append(c.FoundSources, h.Name) |
| } |
| } |
| |
| func embedFileContents(w io.Writer, filename string) error { |
| f, err := os.Open(filename) |
| if err != nil { |
| return err |
| } |
| defer f.Close() |
| |
| _, err = io.Copy(&escapeWriter{w}, bufio.NewReader(f)) |
| return err |
| } |
| |
// escapeWriter is an io.Writer that forwards everything written to it to w.
// Each byte is rewritten so the output can be placed between the double
// quotes of a Go interpreted string literal and still denote the same bytes.
type escapeWriter struct {
	w io.Writer
}

// Write escapes data and passes the result to the underlying writer. Each call
// to the underlying writer carries one escape sequence or one valid UTF-8 rune.
// It returns the number of input bytes consumed and the first error from the
// underlying writer. If that writer fails, the unit being written when it
// failed is still counted as consumed.
//
// Escaping rules (https://golang.org/ref/spec#String_literals and
// https://golang.org/ref/spec#Source_code_representation):
//   - backslash, double quote and newline become \\, \" and \n, because they
//     may not appear unescaped inside the literal;
//   - NUL becomes \x00, because a compiler may disallow NUL in source text;
//   - any byte where utf8.DecodeRune reports utf8.RuneError becomes \xNN.
//     This covers invalid or truncated UTF-8, and also a literal U+FFFD;
//   - each byte of a byte order mark (U+FEFF) becomes \xNN, because a BOM may
//     be disallowed anywhere else in the source;
//   - every other rune is copied unchanged.
//
// A multi-byte rune split across two Write calls does not decode on either
// side of the split, so its bytes are emitted as \xNN escapes. The literal
// still holds the same bytes.
func (w *escapeWriter) Write(data []byte) (n int, err error) {
	const byteOrderMark = '\uFEFF'

	for n < len(data) && err == nil {
		rest := data[n:]
		consumed := 1
		var out []byte

		switch c := rest[0]; {
		case c == '\\':
			out = []byte(`\\`)
		case c == '"':
			out = []byte(`\"`)
		case c == '\n':
			out = []byte(`\n`)
		case c == '\x00':
			out = []byte(`\x00`)
		default:
			if r, size := utf8.DecodeRune(rest); r != utf8.RuneError && r != byteOrderMark {
				out, consumed = rest[:size], size
			} else {
				out = []byte(fmt.Sprintf(`\x%02x`, c))
			}
		}

		_, err = w.w.Write(out)
		n += consumed
	}
	return n, err
}