Merge "[gfxstream] Add dispatcher validity checks" into main
diff --git a/src/gfxstream/codegen/scripts/cereal/common/codegen.py b/src/gfxstream/codegen/scripts/cereal/common/codegen.py
index 91ccf2e..3ffc30c 100644
--- a/src/gfxstream/codegen/scripts/cereal/common/codegen.py
+++ b/src/gfxstream/codegen/scripts/cereal/common/codegen.py
@@ -629,7 +629,9 @@
     def generalLengthAccessGuard(self, vulkanType, parentVarName="parent"):
         return self.makeLengthAccess(vulkanType, parentVarName)[1]
 
-    def vkApiCall(self, api, customPrefix="", globalStatePrefix="", customParameters=None, checkForDeviceLost=False, checkForOutOfMemory=False):
+    def vkApiCall(self, api, customPrefix="", globalStatePrefix="",
+                  customParameters=None, checkForDeviceLost=False,
+                  checkForOutOfMemory=False, checkDispatcher=None):  # checkDispatcher: C condition guarding the call, or None
         callLhs = None
 
         retTypeName = api.getRetTypeExpr()
@@ -637,9 +639,22 @@
 
         if retTypeName != "void":
             retVar = api.getRetVarExpr()
-            self.stmt("%s %s = (%s)0" % (retTypeName, retVar, retTypeName))
+            defaultReturn = "(%s)0" % retTypeName
+            if retTypeName == "VkResult":
+                # Default result when the call is skipped (invalid dispatcher,
+                # snapshot states). TODO: choose a valid error code per call.
+                deviceLostFunctions = {"vkQueueSubmit",
+                                       "vkQueueWaitIdle",
+                                       "vkWaitForFences"}
+                defaultReturn = "VK_ERROR_OUT_OF_HOST_MEMORY"
+                if api.name in deviceLostFunctions:
+                    defaultReturn = "VK_ERROR_DEVICE_LOST"
+            self.stmt("%s %s = %s" % (retTypeName, retVar, defaultReturn))
             callLhs = retVar
 
+        if checkDispatcher:
+            self.beginIf(checkDispatcher)
+
         if customParameters is None:
             self.funcCall(
             callLhs, customPrefix + api.name, [p.paramName for p in api.parameters])
@@ -647,6 +662,9 @@
             self.funcCall(
                 callLhs, customPrefix + api.name, customParameters)
 
+        if checkDispatcher:
+            self.endIf()
+
         if retTypeName == "VkResult" and checkForDeviceLost:
             self.stmt("if ((%s) == VK_ERROR_DEVICE_LOST) %sDeviceLost()" % (callLhs, globalStatePrefix))
 
diff --git a/src/gfxstream/codegen/scripts/cereal/decoder.py b/src/gfxstream/codegen/scripts/cereal/decoder.py
index 7e931a1..674fd60 100644
--- a/src/gfxstream/codegen/scripts/cereal/decoder.py
+++ b/src/gfxstream/codegen/scripts/cereal/decoder.py
@@ -30,6 +30,8 @@
 ]
 
 GLOBAL_COMMANDS_WITHOUT_DISPATCH = [
+    "vkCreateInstance",
+    "vkEnumerateInstanceVersion",
     "vkEnumerateInstanceExtensionProperties",
     "vkEnumerateInstanceLayerProperties",
 ]
@@ -216,6 +218,10 @@
         cgen.stmt("auto vk = dispatch_%s(%s)" %
                   (param.typeName, param.paramName))
         cgen.stmt("// End manual dispatchable handle unboxing for %s" % param.paramName)
+    else:
+        # Fetch vk here too: the generated call is guarded by CC_LIKELY(vk) (e.g. threads with fatal errors)
+        cgen.stmt("auto vk = dispatch_%s(%s)" %
+                  (param.typeName, param.paramName))
 
 
 def emit_transform(typeInfo, param, cgen, variant="tohost"):
@@ -344,12 +350,14 @@
             cgen.stmt("m_state->lock()")
 
     whichDispatch = "vk->"
+    checkDispatcher = "CC_LIKELY(vk)"  # guard condition handed to vkApiCall; None means no guard
     if api.name in GLOBAL_COMMANDS_WITHOUT_DISPATCH:
         whichDispatch = "m_vk->"
+        checkDispatcher = None  # these commands call through m_vk, not a per-handle vk
 
     cgen.vkApiCall(api, customPrefix=whichDispatch, customParameters=customParams, \
         globalStatePrefix=global_state_prefix, checkForDeviceLost=True,
-        checkForOutOfMemory=True)
+        checkForOutOfMemory=True, checkDispatcher=checkDispatcher)
 
     if api.name in driver_workarounds_global_lock_apis:
         if not delay:
@@ -366,7 +374,7 @@
     coreCustomParams = list(map(lambda p: p.paramName, api.parameters))
 
     if delay:
-        cgen.line("std::function<void()> delayed_remove_callback = [%s]() {" % ", ".join(coreCustomParams))
+        cgen.line("std::function<void()> delayed_remove_callback = [vk, %s]() {" % ", ".join(coreCustomParams))
         cgen.stmt("auto m_state = VkDecoderGlobalState::get()")
         customParams = ["nullptr", "nullptr"] + coreCustomParams
     else:
@@ -374,9 +382,13 @@
 
     if context:
         customParams += ["context"]
+
+    checkDispatcher = "CC_LIKELY(vk)"  # vk is also captured by delayed_remove_callback when delay is set
+    if api.name in GLOBAL_COMMANDS_WITHOUT_DISPATCH:
+        checkDispatcher = None
     cgen.vkApiCall(api, customPrefix=global_state_prefix, \
         customParameters=customParams, globalStatePrefix=global_state_prefix, \
-        checkForDeviceLost=True, checkForOutOfMemory=True)
+        checkForDeviceLost=True, checkForOutOfMemory=True, checkDispatcher=checkDispatcher)
 
     if delay:
         cgen.line("};")
@@ -835,6 +847,11 @@
     def onBegin(self,):
         self.module.appendImpl(
             "#define MAX_PACKET_LENGTH %s\n" % MAX_PACKET_LENGTH)
+        self.module.appendImpl(  # __builtin_expect hint used by the generated CC_LIKELY(vk) guards
+            "#define CC_LIKELY(exp)    (__builtin_expect( !!(exp), true ))\n")
+        self.module.appendImpl(  # not referenced by the code in this patch
+            "#define CC_UNLIKELY(exp)  (__builtin_expect( !!(exp), false ))\n")
+
         self.module.appendHeader(decoder_decl_preamble)
         self.module.appendImpl(decoder_impl_preamble)
 
diff --git a/src/gfxstream/codegen/scripts/cereal/subdecode.py b/src/gfxstream/codegen/scripts/cereal/subdecode.py
index 6780aa5..47c8e58 100644
--- a/src/gfxstream/codegen/scripts/cereal/subdecode.py
+++ b/src/gfxstream/codegen/scripts/cereal/subdecode.py
@@ -261,7 +261,7 @@
 
     cgen.vkApiCall(api, customPrefix="vk->", customParameters=customParams,
                     checkForDeviceLost=True, globalStatePrefix=global_state_prefix,
-                    checkForOutOfMemory=True)
+                    checkForOutOfMemory=True, checkDispatcher="CC_LIKELY(vk)")  # skip the call when vk is null
 
     if api.name in driver_workarounds_global_lock_apis:
         cgen.stmt("unlock()")
@@ -274,7 +274,7 @@
         customParams += ["context"];
     cgen.vkApiCall(api, customPrefix=global_state_prefix,
                    customParameters=customParams, checkForDeviceLost=True,
-                   checkForOutOfMemory=True, globalStatePrefix=global_state_prefix)
+                   checkForOutOfMemory=True, globalStatePrefix=global_state_prefix, checkDispatcher="CC_LIKELY(vk)")
 
 
 def emit_default_decoding(typeInfo, api, cgen):
@@ -331,6 +331,10 @@
 
         self.module.appendImpl(
             "#define MAX_PACKET_LENGTH %s\n" % MAX_PACKET_LENGTH)
+        self.module.appendImpl(  # branch hint for the CC_LIKELY(vk) guards; vk is the subDecode parameter
+            "#define CC_LIKELY(exp)    (__builtin_expect( !!(exp), true ))\n")
+        self.module.appendImpl(  # not referenced by the code in this patch
+            "#define CC_UNLIKELY(exp)  (__builtin_expect( !!(exp), false ))\n")
 
         self.module.appendImpl(
             "size_t subDecode(VulkanMemReadingStream* readStream, VulkanDispatch* vk, void* boxed_dispatchHandle, void* dispatchHandle, VkDeviceSize subDecodeDataSize, const void* pSubDecodeData, const VkDecoderContext& context)\n")