//
// Copyright 2022 The ANGLE Project Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
//
// CompiledShaderState.cpp:
//   Implements CompiledShaderState, and helper functions for serializing and deserializing
//   shader variables.
//

#include "common/CompiledShaderState.h"

#include "common/BinaryStream.h"
#include "common/utilities.h"

namespace gl
{
namespace
{
// Returns a copy of the entries of |variableList| whose |active| flag is set, keeping their
// original relative order. |variableList| must be non-null.
template <typename VarT>
std::vector<VarT> GetActiveShaderVariables(const std::vector<VarT> *variableList)
{
    ASSERT(variableList);

    std::vector<VarT> activeVariables;
    for (const VarT &candidate : *variableList)
    {
        if (!candidate.active)
        {
            continue;
        }
        activeVariables.push_back(candidate);
    }
    return activeVariables;
}

// Dereferences a variable list obtained from a compiler query. |variableList| must be non-null;
// the callers in this file copy the referenced list into their own storage.
template <typename VarT>
const std::vector<VarT> &GetShaderVariables(const std::vector<VarT> *variableList)
{
    ASSERT(variableList);
    const std::vector<VarT> &variables = *variableList;
    return variables;
}
}  // namespace

// Returns true if varying |x| has a higher packing priority than varying |y|.
//
// The order is lexicographic on three keys:
//   1. Non-struct variables come before structs (structs have type GL_NONE and go to the end).
//   2. Among non-structs, a lower gl::VariableSortOrder() of the type comes first.
//   3. Larger array-size products come first.
//
// This is used as the comparator for std::sort, so it must be a strict weak ordering.
// gl::VariableSortOrder() maps several distinct types to the same rank. Falling through to
// the array-size key whenever the ranks are equal (not only when the types are identical)
// makes "equivalent" transitive. Every pair the previous comparator ordered is still ordered
// the same way.
bool CompareShaderVar(const sh::ShaderVariable &x, const sh::ShaderVariable &y)
{
    // Special case for handling structs: we sort these to the end of the list.
    const bool xIsStruct = (x.type == GL_NONE);
    const bool yIsStruct = (y.type == GL_NONE);
    if (xIsStruct != yIsStruct)
    {
        return yIsStruct;
    }

    // Both or neither are structs. Only rank real types; VariableSortOrder is never consulted
    // for GL_NONE.
    if (!xIsStruct)
    {
        const int xOrder = gl::VariableSortOrder(x.type);
        const int yOrder = gl::VariableSortOrder(y.type);
        if (xOrder != yOrder)
        {
            return xOrder < yOrder;
        }
    }

    // Same rank (including identical types): bigger arrays get packed first.
    return x.getArraySizeProduct() > y.getArraySizeProduct();
}

// Serializes every field of |var| into |stream|, recursing into |var.fields| for nested
// struct/block members. The sequence of writes defines the on-stream layout. LoadShaderVar()
// reads the same fields in the same order, so any change here must be mirrored there.
void WriteShaderVar(gl::BinaryOutputStream *stream, const sh::ShaderVariable &var)
{
    stream->writeInt(var.type);
    stream->writeInt(var.precision);
    stream->writeString(var.name);
    stream->writeString(var.mappedName);
    stream->writeIntVector(var.arraySizes);
    stream->writeBool(var.staticUse);
    stream->writeBool(var.active);
    // Nested fields: a size_t count followed by each field, written recursively.
    stream->writeInt<size_t>(var.fields.size());
    for (const sh::ShaderVariable &shaderVariable : var.fields)
    {
        WriteShaderVar(stream, shaderVariable);
    }
    stream->writeString(var.structOrBlockName);
    stream->writeString(var.mappedStructOrBlockName);
    stream->writeBool(var.isRowMajorLayout);
    stream->writeInt(var.location);
    stream->writeBool(var.hasImplicitLocation);
    stream->writeInt(var.binding);
    stream->writeInt(var.imageUnitFormat);
    stream->writeInt(var.offset);
    stream->writeBool(var.rasterOrdered);
    stream->writeBool(var.readonly);
    stream->writeBool(var.writeonly);
    stream->writeBool(var.isFragmentInOut);
    stream->writeInt(var.index);
    stream->writeBool(var.yuv);
    stream->writeEnum(var.interpolation);
    stream->writeBool(var.isInvariant);
    stream->writeBool(var.isShaderIOBlock);
    stream->writeBool(var.isPatch);
    stream->writeBool(var.texelFetchStaticUse);
    // LoadShaderVar() restores this value through setParentArrayIndex().
    stream->writeInt(var.getFlattenedOffsetInParentArrays());
}

// Deserializes a shader variable written by WriteShaderVar() into |var|. It reads every field in
// exactly the order WriteShaderVar() wrote them, recursing into nested fields. The nested field
// count comes straight from the stream and is passed to resize() without validation here.
void LoadShaderVar(gl::BinaryInputStream *stream, sh::ShaderVariable *var)
{
    var->type      = stream->readInt<GLenum>();
    var->precision = stream->readInt<GLenum>();
    stream->readString(&var->name);
    stream->readString(&var->mappedName);
    stream->readIntVector<unsigned int>(&var->arraySizes);
    var->staticUse      = stream->readBool();
    var->active         = stream->readBool();
    size_t elementCount = stream->readInt<size_t>();
    var->fields.resize(elementCount);
    // Nested fields were written recursively right after their count.
    for (sh::ShaderVariable &variable : var->fields)
    {
        LoadShaderVar(stream, &variable);
    }
    stream->readString(&var->structOrBlockName);
    stream->readString(&var->mappedStructOrBlockName);
    var->isRowMajorLayout    = stream->readBool();
    var->location            = stream->readInt<int>();
    var->hasImplicitLocation = stream->readBool();
    var->binding             = stream->readInt<int>();
    var->imageUnitFormat     = stream->readInt<GLenum>();
    var->offset              = stream->readInt<int>();
    var->rasterOrdered       = stream->readBool();
    var->readonly            = stream->readBool();
    var->writeonly           = stream->readBool();
    var->isFragmentInOut     = stream->readBool();
    var->index               = stream->readInt<int>();
    var->yuv                 = stream->readBool();
    var->interpolation       = stream->readEnum<sh::InterpolationType>();
    var->isInvariant         = stream->readBool();
    var->isShaderIOBlock     = stream->readBool();
    var->isPatch             = stream->readBool();
    var->texelFetchStaticUse = stream->readBool();
    // WriteShaderVar() stored getFlattenedOffsetInParentArrays() here.
    // NOTE(review): assumes setParentArrayIndex() restores that same value; confirm in
    // ShaderVars.h.
    var->setParentArrayIndex(stream->readInt<int>());
}

// Serializes |block| into |stream|. The names, array size, layout, row-major flag, binding, use
// flags and block type come first, followed by a size_t field count and each field written
// with WriteShaderVar(). LoadShInterfaceBlock() reads the same sequence, so the two must be
// kept in sync.
void WriteShInterfaceBlock(gl::BinaryOutputStream *stream, const sh::InterfaceBlock &block)
{
    stream->writeString(block.name);
    stream->writeString(block.mappedName);
    stream->writeString(block.instanceName);
    stream->writeInt(block.arraySize);
    stream->writeEnum(block.layout);
    stream->writeBool(block.isRowMajorLayout);
    stream->writeInt(block.binding);
    stream->writeBool(block.staticUse);
    stream->writeBool(block.active);
    stream->writeEnum(block.blockType);

    // Block members: count first, then each member.
    stream->writeInt<size_t>(block.fields.size());
    for (const sh::ShaderVariable &shaderVariable : block.fields)
    {
        WriteShaderVar(stream, shaderVariable);
    }
}

// Deserializes an interface block written by WriteShInterfaceBlock() into |block|, reading the
// same fields in the same order. Each member is loaded with LoadShaderVar(). The member count
// read from the stream is passed to resize() without validation here.
void LoadShInterfaceBlock(gl::BinaryInputStream *stream, sh::InterfaceBlock *block)
{
    block->name             = stream->readString();
    block->mappedName       = stream->readString();
    block->instanceName     = stream->readString();
    block->arraySize        = stream->readInt<unsigned int>();
    block->layout           = stream->readEnum<sh::BlockLayoutType>();
    block->isRowMajorLayout = stream->readBool();
    block->binding          = stream->readInt<int>();
    block->staticUse        = stream->readBool();
    block->active           = stream->readBool();
    block->blockType        = stream->readEnum<sh::BlockType>();

    // Block members: count first, then each member.
    block->fields.resize(stream->readInt<size_t>());
    for (sh::ShaderVariable &variable : block->fields)
    {
        LoadShaderVar(stream, &variable);
    }
}

// Creates the state for a shader of |type| before any compile results are gathered.
// Only members with non-trivial starting values are set explicitly:
//   - shaderVersion: 100
//   - numViews: -1 (presumably "not a multiview shader"; confirm against the header)
//   - geometryShaderInvocations: 1
//   - every component of localSize: -1 (presumably "unspecified")
CompiledShaderState::CompiledShaderState(gl::ShaderType type)
    : shaderType(type),
      shaderVersion(100),
      numViews(-1),
      geometryShaderInvocations(1)
{
    for (auto &dimension : localSize)
    {
        dimension = -1;
    }
}

// Out-of-line destructor; member cleanup is entirely compiler-generated.
CompiledShaderState::~CompiledShaderState() = default;

void CompiledShaderState::buildCompiledShaderState(const ShHandle compilerHandle,
                                                   const bool isBinaryOutput)
{
    if (isBinaryOutput)
    {
        compiledBinary = sh::GetObjectBinaryBlob(compilerHandle);
    }
    else
    {
        translatedSource = sh::GetObjectCode(compilerHandle);
    }

    // Gather the shader information
    shaderVersion = sh::GetShaderVersion(compilerHandle);

    uniforms            = GetShaderVariables(sh::GetUniforms(compilerHandle));
    uniformBlocks       = GetShaderVariables(sh::GetUniformBlocks(compilerHandle));
    shaderStorageBlocks = GetShaderVariables(sh::GetShaderStorageBlocks(compilerHandle));
    specConstUsageBits  = SpecConstUsageBits(sh::GetShaderSpecConstUsageBits(compilerHandle));

    switch (shaderType)
    {
        case gl::ShaderType::Compute:
        {
            allAttributes    = GetShaderVariables(sh::GetAttributes(compilerHandle));
            activeAttributes = GetActiveShaderVariables(&allAttributes);
            localSize        = sh::GetComputeShaderLocalGroupSize(compilerHandle);
            break;
        }
        case gl::ShaderType::Vertex:
        {
            outputVaryings   = GetShaderVariables(sh::GetOutputVaryings(compilerHandle));
            allAttributes    = GetShaderVariables(sh::GetAttributes(compilerHandle));
            activeAttributes = GetActiveShaderVariables(&allAttributes);
            hasClipDistance  = sh::HasClipDistanceInVertexShader(compilerHandle);
            numViews         = sh::GetVertexShaderNumViews(compilerHandle);
            break;
        }
        case gl::ShaderType::Fragment:
        {
            allAttributes    = GetShaderVariables(sh::GetAttributes(compilerHandle));
            activeAttributes = GetActiveShaderVariables(&allAttributes);
            inputVaryings    = GetShaderVariables(sh::GetInputVaryings(compilerHandle));
            // TODO(jmadill): Figure out why we only sort in the FS, and if we need to.
            std::sort(inputVaryings.begin(), inputVaryings.end(), CompareShaderVar);
            activeOutputVariables =
                GetActiveShaderVariables(sh::GetOutputVariables(compilerHandle));
            hasDiscard              = sh::HasDiscardInFragmentShader(compilerHandle);
            enablesPerSampleShading = sh::EnablesPerSampleShading(compilerHandle);
            advancedBlendEquations =
                gl::BlendEquationBitSet(sh::GetAdvancedBlendEquations(compilerHandle));
            break;
        }
        case gl::ShaderType::Geometry:
        {
            inputVaryings  = GetShaderVariables(sh::GetInputVaryings(compilerHandle));
            outputVaryings = GetShaderVariables(sh::GetOutputVaryings(compilerHandle));

            if (sh::HasValidGeometryShaderInputPrimitiveType(compilerHandle))
            {
                geometryShaderInputPrimitiveType = gl::FromGLenum<gl::PrimitiveMode>(
                    sh::GetGeometryShaderInputPrimitiveType(compilerHandle));
            }
            if (sh::HasValidGeometryShaderOutputPrimitiveType(compilerHandle))
            {
                geometryShaderOutputPrimitiveType = gl::FromGLenum<gl::PrimitiveMode>(
                    sh::GetGeometryShaderOutputPrimitiveType(compilerHandle));
            }
            if (sh::HasValidGeometryShaderMaxVertices(compilerHandle))
            {
                geometryShaderMaxVertices = sh::GetGeometryShaderMaxVertices(compilerHandle);
            }
            geometryShaderInvocations = sh::GetGeometryShaderInvocations(compilerHandle);
            break;
        }
        case gl::ShaderType::TessControl:
        {
            inputVaryings             = GetShaderVariables(sh::GetInputVaryings(compilerHandle));
            outputVaryings            = GetShaderVariables(sh::GetOutputVaryings(compilerHandle));
            tessControlShaderVertices = sh::GetTessControlShaderVertices(compilerHandle);
            break;
        }
        case gl::ShaderType::TessEvaluation:
        {
            inputVaryings  = GetShaderVariables(sh::GetInputVaryings(compilerHandle));
            outputVaryings = GetShaderVariables(sh::GetOutputVaryings(compilerHandle));
            if (sh::HasValidTessGenMode(compilerHandle))
            {
                tessGenMode = sh::GetTessGenMode(compilerHandle);
            }
            if (sh::HasValidTessGenSpacing(compilerHandle))
            {
                tessGenSpacing = sh::GetTessGenSpacing(compilerHandle);
            }
            if (sh::HasValidTessGenVertexOrder(compilerHandle))
            {
                tessGenVertexOrder = sh::GetTessGenVertexOrder(compilerHandle);
            }
            if (sh::HasValidTessGenPointMode(compilerHandle))
            {
                tessGenPointMode = sh::GetTessGenPointMode(compilerHandle);
            }
            break;
        }

        default:
            UNREACHABLE();
    }
}

void CompiledShaderState::serialize(gl::BinaryOutputStream &stream) const
{
    stream.writeInt(shaderVersion);

    stream.writeInt(uniforms.size());
    for (const sh::ShaderVariable &shaderVariable : uniforms)
    {
        WriteShaderVar(&stream, shaderVariable);
    }

    stream.writeInt(uniformBlocks.size());
    for (const sh::InterfaceBlock &interfaceBlock : uniformBlocks)
    {
        WriteShInterfaceBlock(&stream, interfaceBlock);
    }

    stream.writeInt(shaderStorageBlocks.size());
    for (const sh::InterfaceBlock &interfaceBlock : shaderStorageBlocks)
    {
        WriteShInterfaceBlock(&stream, interfaceBlock);
    }

    stream.writeInt(specConstUsageBits.bits());

    switch (shaderType)
    {
        case gl::ShaderType::Compute:
        {
            stream.writeInt(allAttributes.size());
            for (const sh::ShaderVariable &shaderVariable : allAttributes)
            {
                WriteShaderVar(&stream, shaderVariable);
            }
            stream.writeInt(activeAttributes.size());
            for (const sh::ShaderVariable &shaderVariable : activeAttributes)
            {
                WriteShaderVar(&stream, shaderVariable);
            }
            stream.writeInt(localSize[0]);
            stream.writeInt(localSize[1]);
            stream.writeInt(localSize[2]);
            break;
        }

        case gl::ShaderType::Vertex:
        {
            stream.writeInt(outputVaryings.size());
            for (const sh::ShaderVariable &shaderVariable : outputVaryings)
            {
                WriteShaderVar(&stream, shaderVariable);
            }
            stream.writeInt(allAttributes.size());
            for (const sh::ShaderVariable &shaderVariable : allAttributes)
            {
                WriteShaderVar(&stream, shaderVariable);
            }
            stream.writeInt(activeAttributes.size());
            for (const sh::ShaderVariable &shaderVariable : activeAttributes)
            {
                WriteShaderVar(&stream, shaderVariable);
            }
            stream.writeBool(hasClipDistance);
            stream.writeInt(numViews);
            break;
        }
        case gl::ShaderType::Fragment:
        {
            stream.writeInt(inputVaryings.size());
            for (const sh::ShaderVariable &shaderVariable : inputVaryings)
            {
                WriteShaderVar(&stream, shaderVariable);
            }
            stream.writeInt(activeOutputVariables.size());
            for (const sh::ShaderVariable &shaderVariable : activeOutputVariables)
            {
                WriteShaderVar(&stream, shaderVariable);
            }
            stream.writeBool(hasDiscard);
            stream.writeBool(enablesPerSampleShading);
            stream.writeInt(advancedBlendEquations.bits());
            break;
        }
        case gl::ShaderType::Geometry:
        {
            bool valid;

            stream.writeInt(inputVaryings.size());
            for (const sh::ShaderVariable &shaderVariable : inputVaryings)
            {
                WriteShaderVar(&stream, shaderVariable);
            }
            stream.writeInt(outputVaryings.size());
            for (const sh::ShaderVariable &shaderVariable : outputVaryings)
            {
                WriteShaderVar(&stream, shaderVariable);
            }

            valid = (bool)geometryShaderInputPrimitiveType.valid();
            stream.writeBool(valid);
            if (valid)
            {
                unsigned char value = (unsigned char)geometryShaderInputPrimitiveType.value();
                stream.writeBytes(&value, 1);
            }
            valid = (bool)geometryShaderOutputPrimitiveType.valid();
            stream.writeBool(valid);
            if (valid)
            {
                unsigned char value = (unsigned char)geometryShaderOutputPrimitiveType.value();
                stream.writeBytes(&value, 1);
            }
            valid = geometryShaderMaxVertices.valid();
            stream.writeBool(valid);
            if (valid)
            {
                int value = (int)geometryShaderMaxVertices.value();
                stream.writeInt(value);
            }

            stream.writeInt(geometryShaderInvocations);
            break;
        }
        case gl::ShaderType::TessControl:
        {
            stream.writeInt(inputVaryings.size());
            for (const sh::ShaderVariable &shaderVariable : inputVaryings)
            {
                WriteShaderVar(&stream, shaderVariable);
            }
            stream.writeInt(outputVaryings.size());
            for (const sh::ShaderVariable &shaderVariable : outputVaryings)
            {
                WriteShaderVar(&stream, shaderVariable);
            }
            stream.writeInt(tessControlShaderVertices);
            break;
        }
        case gl::ShaderType::TessEvaluation:
        {
            unsigned int value;

            stream.writeInt(inputVaryings.size());
            for (const sh::ShaderVariable &shaderVariable : inputVaryings)
            {
                WriteShaderVar(&stream, shaderVariable);
            }
            stream.writeInt(outputVaryings.size());
            for (const sh::ShaderVariable &shaderVariable : outputVaryings)
            {
                WriteShaderVar(&stream, shaderVariable);
            }

            value = (unsigned int)(tessGenMode);
            stream.writeInt(value);

            value = (unsigned int)tessGenSpacing;
            stream.writeInt(value);

            value = (unsigned int)tessGenVertexOrder;
            stream.writeInt(value);

            value = (unsigned int)tessGenPointMode;
            stream.writeInt(value);
            break;
        }
        default:
            UNREACHABLE();
    }

    stream.writeIntVector(compiledBinary);
}

void CompiledShaderState::deserialize(gl::BinaryInputStream &stream)
{
    stream.readInt(&shaderVersion);

    size_t size;
    size = stream.readInt<size_t>();
    uniforms.resize(size);
    for (sh::ShaderVariable &shaderVariable : uniforms)
    {
        LoadShaderVar(&stream, &shaderVariable);
    }

    size = stream.readInt<size_t>();
    uniformBlocks.resize(size);
    for (sh::InterfaceBlock &interfaceBlock : uniformBlocks)
    {
        LoadShInterfaceBlock(&stream, &interfaceBlock);
    }

    size = stream.readInt<size_t>();
    shaderStorageBlocks.resize(size);
    for (sh::InterfaceBlock &interfaceBlock : shaderStorageBlocks)
    {
        LoadShInterfaceBlock(&stream, &interfaceBlock);
    }

    specConstUsageBits = SpecConstUsageBits(stream.readInt<uint32_t>());

    switch (shaderType)
    {
        case gl::ShaderType::Compute:
        {
            size = stream.readInt<size_t>();
            allAttributes.resize(size);
            for (sh::ShaderVariable &shaderVariable : allAttributes)
            {
                LoadShaderVar(&stream, &shaderVariable);
            }
            size = stream.readInt<size_t>();
            activeAttributes.resize(size);
            for (sh::ShaderVariable &shaderVariable : activeAttributes)
            {
                LoadShaderVar(&stream, &shaderVariable);
            }
            stream.readInt(&localSize[0]);
            stream.readInt(&localSize[1]);
            stream.readInt(&localSize[2]);
            break;
        }
        case gl::ShaderType::Vertex:
        {
            size = stream.readInt<size_t>();
            outputVaryings.resize(size);
            for (sh::ShaderVariable &shaderVariable : outputVaryings)
            {
                LoadShaderVar(&stream, &shaderVariable);
            }
            size = stream.readInt<size_t>();
            allAttributes.resize(size);
            for (sh::ShaderVariable &shaderVariable : allAttributes)
            {
                LoadShaderVar(&stream, &shaderVariable);
            }
            size = stream.readInt<size_t>();
            activeAttributes.resize(size);
            for (sh::ShaderVariable &shaderVariable : activeAttributes)
            {
                LoadShaderVar(&stream, &shaderVariable);
            }
            stream.readBool(&hasClipDistance);
            stream.readInt(&numViews);
            break;
        }
        case gl::ShaderType::Fragment:
        {
            size = stream.readInt<size_t>();
            inputVaryings.resize(size);
            for (sh::ShaderVariable &shaderVariable : inputVaryings)
            {
                LoadShaderVar(&stream, &shaderVariable);
            }
            size = stream.readInt<size_t>();
            activeOutputVariables.resize(size);
            for (sh::ShaderVariable &shaderVariable : activeOutputVariables)
            {
                LoadShaderVar(&stream, &shaderVariable);
            }
            stream.readBool(&hasDiscard);
            stream.readBool(&enablesPerSampleShading);
            int advancedBlendEquationBits;
            stream.readInt(&advancedBlendEquationBits);
            advancedBlendEquations = gl::BlendEquationBitSet(advancedBlendEquationBits);
            break;
        }
        case gl::ShaderType::Geometry:
        {
            bool valid;

            size = stream.readInt<size_t>();
            inputVaryings.resize(size);
            for (sh::ShaderVariable &shaderVariable : inputVaryings)
            {
                LoadShaderVar(&stream, &shaderVariable);
            }
            size = stream.readInt<size_t>();
            outputVaryings.resize(size);
            for (sh::ShaderVariable &shaderVariable : outputVaryings)
            {
                LoadShaderVar(&stream, &shaderVariable);
            }

            stream.readBool(&valid);
            if (valid)
            {
                unsigned char value;
                stream.readBytes(&value, 1);
                geometryShaderInputPrimitiveType = static_cast<gl::PrimitiveMode>(value);
            }
            else
            {
                geometryShaderInputPrimitiveType.reset();
            }

            stream.readBool(&valid);
            if (valid)
            {
                unsigned char value;
                stream.readBytes(&value, 1);
                geometryShaderOutputPrimitiveType = static_cast<gl::PrimitiveMode>(value);
            }
            else
            {
                geometryShaderOutputPrimitiveType.reset();
            }

            stream.readBool(&valid);
            if (valid)
            {
                int value;
                stream.readInt(&value);
                geometryShaderMaxVertices = static_cast<GLint>(value);
            }
            else
            {
                geometryShaderMaxVertices.reset();
            }

            stream.readInt(&geometryShaderInvocations);
            break;
        }
        case gl::ShaderType::TessControl:
        {
            size = stream.readInt<size_t>();
            inputVaryings.resize(size);
            for (sh::ShaderVariable &shaderVariable : inputVaryings)
            {
                LoadShaderVar(&stream, &shaderVariable);
            }
            size = stream.readInt<size_t>();
            outputVaryings.resize(size);
            for (sh::ShaderVariable &shaderVariable : outputVaryings)
            {
                LoadShaderVar(&stream, &shaderVariable);
            }
            stream.readInt(&tessControlShaderVertices);
            break;
        }
        case gl::ShaderType::TessEvaluation:
        {
            unsigned int value;

            size = stream.readInt<size_t>();
            inputVaryings.resize(size);
            for (sh::ShaderVariable &shaderVariable : inputVaryings)
            {
                LoadShaderVar(&stream, &shaderVariable);
            }
            size = stream.readInt<size_t>();
            outputVaryings.resize(size);
            for (sh::ShaderVariable &shaderVariable : outputVaryings)
            {
                LoadShaderVar(&stream, &shaderVariable);
            }

            stream.readInt(&value);
            tessGenMode = (GLenum)value;

            stream.readInt(&value);
            tessGenSpacing = (GLenum)value;

            stream.readInt(&value);
            tessGenVertexOrder = (GLenum)value;

            stream.readInt(&value);
            tessGenPointMode = (GLenum)value;
            break;
        }
        default:
            UNREACHABLE();
    }

    stream.readIntVector<unsigned int>(&compiledBinary);
}
}  // namespace gl
