blob: e68da974436bc8011daf76cce67e67ceea0181ae [file] [log] [blame] [edit]
// Copyright 2017 The Bazel Authors. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// embed generates a .go file from the contents of a list of data files. It is
// invoked by go_embed_data as an action.
package main
import (
"archive/tar"
"archive/zip"
"bufio"
"errors"
"flag"
"fmt"
"io"
"log"
"os"
"path"
"path/filepath"
"strconv"
"strings"
"text/template"
"unicode/utf8"
)
// headerTpl renders the preamble of the generated file: a "DO NOT EDIT" notice
// naming the rule label, followed by the package clause. It is executed with a
// *configuration and reads its Label and Package fields.
var headerTpl = template.Must(template.New("embed").Parse(`// Generated by go_embed_data for {{.Label}}. DO NOT EDIT.
package {{.Package}}
`))
// multiFooterTpl renders the map variable emitted at the end of the generated
// file in -multi mode. It is executed with a *configuration and declares
// {{.Var}} as a map whose key for each entry of FoundSources is the quoted
// string returned by configuration.Key and whose value is the per-file
// variable {{.Var}}_<index>, where <index> is the entry's position in
// FoundSources. The element type of the map comes from configuration.Type.
var multiFooterTpl = template.Must(template.New("embed").Parse(`
var {{.Var}} = map[string]{{.Type}}{
{{- range $i, $f := .FoundSources}}
{{$.Key $f}}: {{$.Var}}_{{$i}},
{{- end}}
}
`))
// main configures the logger to print plain "embed: "-prefixed messages and
// hands the command line to run, exiting through log.Fatal if run reports an
// error.
func main() {
	log.SetFlags(0) // don't print timestamps
	log.SetPrefix("embed: ")

	err := run(os.Args)
	if err != nil {
		log.Fatal(err)
	}
}
// configuration holds the parsed command-line options together with the list
// of files that have been embedded so far. The exported fields and the Type
// and Key methods are the ones referenced by the output templates.
type configuration struct {
	// Label is the label of the rule being executed.
	Label string
	// Package is the Go package name of the generated file.
	Package string
	// Var is the name of the generated variable.
	Var string
	// Multi selects a map of values rather than a single value.
	Multi bool
	// sources are the positional command-line arguments.
	sources []string
	// FoundSources records, in embedding order, the name of every file that
	// has been embedded; an entry's index matches the suffix of its variable.
	FoundSources []string
	// out is the path of the Go file to generate.
	out string
	// workspace is the workspace name stripped from "external/" paths by Key.
	workspace string
	// flatten makes Key use only the base name of each file.
	flatten bool
	// unpack treats the sources as archives whose members are embedded.
	unpack bool
	// strData stores contents as strings instead of byte slices.
	strData bool
}

// Type returns the Go type of the embedded values: "string" when strData is
// set and "[]byte" otherwise.
func (c *configuration) Type() string {
	if c.strData {
		return "string"
	}
	return "[]byte"
}

// Key returns the quoted Go string literal used as the map key for filename.
// With flatten set the key is the base name of filename. Otherwise a leading
// "external/<workspace>/" is removed and slashes are converted to the
// operating system's path separator.
func (c *configuration) Key(filename string) string {
	if c.flatten {
		return strconv.Quote(path.Base(filename))
	}
	relative := strings.TrimPrefix(filename, "external/"+c.workspace+"/")
	return strconv.Quote(filepath.FromSlash(relative))
}
// run parses args, creates the output file and writes the generated Go source
// to it: the header, followed by either a single value or, in -multi mode, one
// variable per source plus a lookup map.
//
// Output goes through a buffered writer, so a failed write may only surface
// when the buffer is flushed or the file is closed. Both happen in a deferred
// function, and the first error from either is returned unless an earlier
// step already failed.
func run(args []string) (err error) {
	c, err := newConfiguration(args)
	if err != nil {
		return err
	}

	f, err := os.Create(c.out)
	if err != nil {
		return err
	}
	w := bufio.NewWriter(f)
	defer func() {
		// Flush before closing, and always close so the descriptor is not
		// leaked; keep whichever error happened first.
		if flushErr := w.Flush(); err == nil {
			err = flushErr
		}
		if closeErr := f.Close(); err == nil {
			err = closeErr
		}
	}()

	err = headerTpl.Execute(w, c)
	if err != nil {
		return err
	}
	if c.Multi {
		return embedMultipleFiles(c, w)
	}
	return embedSingleFile(c, w)
}
// newConfiguration builds a configuration from the command line in args, where
// args[0] is the program name. The flag set uses flag.ExitOnError, so a
// malformed command line terminates the process inside Parse.
//
// An error is returned when any of -label, -package, -var, -out or -workspace
// is empty, when -multi is absent and the number of sources is not exactly
// one, or when -unpack is given without -multi or with a source whose
// extension is neither ".zip" nor ".tar".
func newConfiguration(args []string) (*configuration, error) {
	c := new(configuration)

	fs := flag.NewFlagSet("embed", flag.ExitOnError)
	fs.StringVar(&c.Label, "label", "", "Label of the rule being executed (required)")
	fs.StringVar(&c.Package, "package", "", "Go package name (required)")
	fs.StringVar(&c.Var, "var", "", "Variable name (required)")
	fs.BoolVar(&c.Multi, "multi", false, "Whether the variable is a map or a single value")
	fs.StringVar(&c.out, "out", "", "Go file to generate (required)")
	fs.StringVar(&c.workspace, "workspace", "", "Name of the workspace (required)")
	fs.BoolVar(&c.flatten, "flatten", false, "Whether to access files by base name")
	fs.BoolVar(&c.strData, "string", false, "Whether to store contents as strings")
	fs.BoolVar(&c.unpack, "unpack", false, "Whether to treat files as archives to unpack.")
	fs.Parse(args[1:])

	// Required options, checked in the order their errors are reported.
	required := []struct {
		value, message string
	}{
		{c.Label, "error: -label option not provided"},
		{c.Package, "error: -package option not provided"},
		{c.Var, "error: -var option not provided"},
		{c.out, "error: -out option not provided"},
		{c.workspace, "error: -workspace option not provided"},
	}
	for _, req := range required {
		if req.value == "" {
			return nil, errors.New(req.message)
		}
	}

	c.sources = fs.Args()
	if count := len(c.sources); !c.Multi && count != 1 {
		return nil, fmt.Errorf("error: -multi flag not given, so want exactly one source; got %d", count)
	}

	if !c.unpack {
		return c, nil
	}
	if !c.Multi {
		return nil, errors.New("error: -multi flag is required for -unpack mode.")
	}
	for _, src := range c.sources {
		switch ext := filepath.Ext(src); ext {
		case ".zip", ".tar":
			// Supported archive format.
		default:
			return nil, fmt.Errorf("error: -unpack flag expects .zip or .tar extension (got %q)", ext)
		}
	}
	return c, nil
}
// embedSingleFile writes a declaration of the variable c.Var whose value is the
// escaped contents of the first (and, in single mode, only) source. The value
// is a string literal when c.strData is set and a []byte conversion of a
// string literal otherwise.
func embedSingleFile(c *configuration, w io.Writer) error {
	prefix, suffix := "[]byte(\"", "\")\n"
	if c.strData {
		prefix, suffix = "\"", "\"\n"
	}

	_, err := fmt.Fprintf(w, "var %s = %s", c.Var, prefix)
	if err != nil {
		return err
	}
	err = embedFileContents(w, c.sources[0])
	if err != nil {
		return err
	}
	_, err = fmt.Fprint(w, suffix)
	return err
}
// embedMultipleFiles writes a parenthesized var block holding one variable per
// found source, named <c.Var>_<index>, and then the footer template that maps
// each source's key to its variable. Values are string literals when c.strData
// is set and []byte conversions of string literals otherwise.
func embedMultipleFiles(c *configuration, w io.Writer) error {
	prefix, suffix := "[]byte(\"", "\")\n"
	if c.strData {
		prefix, suffix = "\"", "\"\n"
	}

	// writeEntry emits the declaration for the source with the given index,
	// escaping its contents so they form a valid string literal.
	writeEntry := func(index int, contents io.Reader) error {
		_, err := fmt.Fprintf(w, "\t%s_%d = %s", c.Var, index, prefix)
		if err != nil {
			return err
		}
		_, err = io.Copy(&escapeWriter{w}, contents)
		if err != nil {
			return err
		}
		_, err = fmt.Fprint(w, suffix)
		return err
	}

	_, err := fmt.Fprint(w, "var (\n")
	if err != nil {
		return err
	}
	err = findSources(c, writeEntry)
	if err != nil {
		return err
	}
	_, err = fmt.Fprint(w, ")\n")
	if err != nil {
		return err
	}
	return multiFooterTpl.Execute(w, c)
}
// findSources calls cb once for every file to embed, passing the index the file
// will occupy in c.FoundSources and a reader over its contents. After cb
// succeeds, the file's name is appended to c.FoundSources.
//
// Without c.unpack each source is opened and embedded directly. With c.unpack
// each source is handed to the zip or tar walker chosen by its extension; any
// other extension panics, since newConfiguration rejects such sources.
func findSources(c *configuration, cb func(i int, f io.Reader) error) error {
	if !c.unpack {
		for _, name := range c.sources {
			src, err := os.Open(name)
			if err != nil {
				return err
			}
			cbErr := cb(len(c.FoundSources), bufio.NewReader(src))
			src.Close()
			if cbErr != nil {
				return cbErr
			}
			c.FoundSources = append(c.FoundSources, name)
		}
		return nil
	}

	for _, archive := range c.sources {
		var err error
		switch ext := filepath.Ext(archive); ext {
		case ".zip":
			err = findZipSources(c, archive, cb)
		case ".tar":
			err = findTarSources(c, archive, cb)
		default:
			panic("unknown archive extension: " + ext)
		}
		if err != nil {
			return err
		}
	}
	return nil
}
// findZipSources opens the zip archive at filename and calls cb once for every
// file entry, passing the index the entry will occupy in c.FoundSources and a
// reader over its decompressed contents. After cb succeeds, the entry's name
// is appended to c.FoundSources.
//
// Directory entries are skipped. They carry no data, so embedding them would
// only add an empty value keyed by the directory name, and findTarSources
// likewise embeds regular files only.
func findZipSources(c *configuration, filename string, cb func(i int, f io.Reader) error) error {
	r, err := zip.OpenReader(filename)
	if err != nil {
		return err
	}
	defer r.Close()

	for _, file := range r.File {
		if file.FileInfo().IsDir() {
			continue
		}
		f, err := file.Open()
		if err != nil {
			return err
		}
		err = cb(len(c.FoundSources), f)
		// Close right away instead of deferring so that at most one entry
		// is open at a time.
		f.Close()
		if err != nil {
			return err
		}
		c.FoundSources = append(c.FoundSources, file.Name)
	}
	return nil
}
// findTarSources reads the tar archive at filename and calls cb once for every
// entry whose type flag is tar.TypeReg, passing the index the entry will
// occupy in c.FoundSources and a reader limited to the entry's size. After cb
// succeeds, the entry's name is appended to c.FoundSources. Entries of any
// other type are skipped.
func findTarSources(c *configuration, filename string, cb func(i int, f io.Reader) error) error {
	archive, err := os.Open(filename)
	if err != nil {
		return err
	}
	defer archive.Close()

	tr := tar.NewReader(bufio.NewReader(archive))
	for {
		hdr, err := tr.Next()
		switch {
		case err == io.EOF:
			// End of archive: every entry has been visited.
			return nil
		case err != nil:
			return err
		case hdr.Typeflag != tar.TypeReg:
			continue
		}

		index := len(c.FoundSources)
		if err := cb(index, io.LimitReader(tr, hdr.Size)); err != nil {
			return err
		}
		c.FoundSources = append(c.FoundSources, hdr.Name)
	}
}
// embedFileContents streams the file at filename to w through an escapeWriter,
// so that the bytes written can sit between the double quotes of a Go string
// literal. It returns the first error from opening or copying the file.
func embedFileContents(w io.Writer, filename string) error {
	src, err := os.Open(filename)
	if err != nil {
		return err
	}
	defer src.Close()

	escaped := &escapeWriter{w}
	if _, err := io.Copy(escaped, bufio.NewReader(src)); err != nil {
		return err
	}
	return nil
}
// escapeWriter wraps an io.Writer and escapes everything written to it so that
// the output can be placed between the double quotes of an interpreted Go
// string literal and still denote the original bytes.
type escapeWriter struct {
	w io.Writer
}

// Write escapes data and forwards it to the underlying writer. Backslash,
// double quote, newline and NUL are written as escape sequences. Valid UTF-8
// sequences other than the byte order mark are passed through unchanged. Every
// other byte is written as a \xNN escape; this covers invalid UTF-8, the
// bytes of a byte order mark and the bytes of an encoded U+FFFD.
//
// It returns the number of bytes of data that were fully escaped and written.
// When the underlying writer fails, the byte or rune being written is not
// counted, so n never overstates progress.
func (w *escapeWriter) Write(data []byte) (n int, err error) {
	// https://golang.org/ref/spec#Source_code_representation: "Implementation
	// restriction: […] A byte order mark may be disallowed anywhere else in
	// the source."
	const byteOrderMark = '\uFEFF'

	for n < len(data) {
		// https://golang.org/ref/spec#String_literals: "Within the quotes, any
		// character may appear except newline and unescaped double quote. The
		// text between the quotes forms the value of the literal, with backslash
		// escapes interpreted as they are in rune literals […]."
		b := data[n]
		size := 1 // number of input bytes consumed by this step
		switch b {
		case '\\':
			_, err = w.w.Write([]byte(`\\`))
		case '"':
			_, err = w.w.Write([]byte(`\"`))
		case '\n':
			_, err = w.w.Write([]byte(`\n`))
		case '\x00':
			// https://golang.org/ref/spec#Source_code_representation: "Implementation
			// restriction: For compatibility with other tools, a compiler may
			// disallow the NUL character (U+0000) in the source text."
			_, err = w.w.Write([]byte(`\x00`))
		default:
			r, runeSize := utf8.DecodeRune(data[n:])
			if r != utf8.RuneError && r != byteOrderMark {
				size = runeSize
				_, err = w.w.Write(data[n : n+size])
			} else {
				_, err = fmt.Fprintf(w.w, `\x%02x`, b)
			}
		}
		if err != nil {
			// Stop without counting the input whose write failed.
			return n, err
		}
		n += size
	}
	return n, nil
}