跳到主要内容

shader object用法

· 阅读需 8 分钟

VK_EXT_shader_object 是 Vulkan 的一个扩展(在 Vulkan 1.3+ 常见 GPU 上支持),其目标是简化 shader 使用流程,提供一种比传统 pipeline 更轻量、更灵活的方式来管理和绑定着色器,不再需要创建复杂的 graphics/compute pipeline 对象

✅ 核心作用

传统 Vulkan:

vkCreateShaderModule
vkCreatePipelineLayout
vkCreateGraphicsPipelines
vkBeginCommandBuffer
vkCmdBindPipeline
vkCmdDraw
vkEndCommandBuffer

使用 VK_EXT_shader_object 后:

vkCreateShadersEXT
vkBeginCommandBuffer
vkCmdBeginRendering
vkCmdBindShadersEXT
vkCmdDraw
vkCmdEndRendering
vkEndCommandBuffer

🚀 主要优势

  • 免 pipeline:不再需要提前组合多个 shader 成 pipeline
  • shader 热插拔:可以动态替换单个阶段(如 fragment shader)
  • 灵活:非常适合动态渲染系统、调试工具、可编程 pipeline

🧱 关键结构和函数

1. 创建 Shader Object

// Describe one shader stage. Designators are listed in the struct's declaration
// order (codeSize before pCode), which C++20 designated initializers require.
VkShaderCreateInfoEXT shaderInfo = {
    .sType = VK_STRUCTURE_TYPE_SHADER_CREATE_INFO_EXT,
    .stage = VK_SHADER_STAGE_VERTEX_BIT,        // stage implemented by this shader object
    // Stages that are allowed to follow this one at draw time. Zero means
    // "this is the last stage", so a vertex shader that feeds a fragment
    // shader must list VK_SHADER_STAGE_FRAGMENT_BIT here.
    .nextStage = VK_SHADER_STAGE_FRAGMENT_BIT,
    .codeType = VK_SHADER_CODE_TYPE_SPIRV_EXT,  // pCode holds SPIR-V
    .codeSize = spirvSize,                      // size of the code in bytes
    .pCode = spirvData,
    .pName = "main",                            // entry point
    // Descriptor set layouts and push-constant ranges (setLayoutCount/pSetLayouts,
    // pushConstantRangeCount/pPushConstantRanges) go here; they play the role
    // that the pipeline layout plays for a classic pipeline.
};

VkShaderEXT shader = VK_NULL_HANDLE;
// Returns a VkResult; real code should check it against VK_SUCCESS.
vkCreateShadersEXT(device, 1, &shaderInfo, nullptr, &shader);

2. 绑定 Shader(代替 vkCmdBindPipeline)

vkCmdBindShadersEXT(
commandBuffer,
VK_SHADER_STAGE_VERTEX_BIT | VK_SHADER_STAGE_FRAGMENT_BIT,
0, // firstBinding
&vertexShader, // pShaders (按顺序提供每个阶段的 shader)
);

3. 设置 pipeline 状态(dynamic state)

Shader Object 不绑定 pipeline,所以你必须通过命令设置 pipeline 动态状态(rasterization, blend, topology 等):

vkCmdSetRasterizerDiscardEnable(cmd, VK_FALSE);
vkCmdSetDepthTestEnable(cmd, VK_TRUE);
vkCmdSetPrimitiveTopology(cmd, VK_PRIMITIVE_TOPOLOGY_TRIANGLE_LIST);
// With shader objects the viewport/scissor *count* is dynamic state as well,
// so the "WithCount" variants must be used instead of vkCmdSetViewport/vkCmdSetScissor.
vkCmdSetViewportWithCount(cmd, 1, &viewport);
vkCmdSetScissorWithCount(cmd, 1, &scissor);

使用 shader object 绘制时,当前绑定的着色器阶段和已启用的设备特性所要求的每一项动态状态,都必须在绘制前通过对应的 vkCmdSet* 命令设置,否则行为未定义!

4. 资源绑定(Descriptor Sets)

仍然使用传统 vkCmdBindDescriptorSets 方式来绑定资源(uniforms, textures 等):

// Bind one descriptor set at set index 0 for the graphics bind point,
// without any dynamic offsets.
const uint32_t firstSet = 0;
const uint32_t setCount = 1;
vkCmdBindDescriptorSets(cmd, VK_PIPELINE_BIND_POINT_GRAPHICS, pipelineLayout,
                        firstSet, setCount, &descriptorSet,
                        /*dynamicOffsetCount=*/0, /*pDynamicOffsets=*/nullptr);

5. 销毁 Shader

// Destroy the shader object once it is no longer needed (nullptr is passed as the allocator argument).
vkDestroyShaderEXT(device, shader, nullptr);

📦示例代码

#include <vulkan/vulkan.h>

#include <cassert>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <vector>

// SPIR-V binaries for the two stages, stored as 32-bit words.
// These are placeholders: load or embed real SPIR-V before using them.
std::vector<uint32_t> vertexShaderSPIRV{/* ... vertex-stage SPIR-V words ... */};
std::vector<uint32_t> fragmentShaderSPIRV{/* ... fragment-stage SPIR-V words ... */};

// Minimal error-handling macro.
// The argument is ALWAYS evaluated exactly once. Wrapping the call in assert()
// would remove the Vulkan call itself when NDEBUG is defined, so the check is
// written out by hand. On failure the result code and location are printed and
// the process aborts.
#define VK_CHECK(result)                                                         \
    do {                                                                         \
        const auto vkCheckResult_ = (result);                                    \
        if (vkCheckResult_ != VK_SUCCESS) {                                      \
            std::cerr << "Vulkan call failed with VkResult "                     \
                      << static_cast<int>(vkCheckResult_) << " at " << __FILE__  \
                      << ":" << __LINE__ << '\n';                                \
            std::abort();                                                        \
        }                                                                        \
    } while (0)

// Bundle of the core Vulkan handles shared by the example.
// Only a minimal set is kept here; a real application needs a complete
// initialization path and more state.
// Every handle defaults to VK_NULL_HANDLE so a default-constructed context
// never carries indeterminate values.
struct VulkanContext {
    VkInstance instance = VK_NULL_HANDLE;
    VkPhysicalDevice physicalDevice = VK_NULL_HANDLE;
    VkDevice device = VK_NULL_HANDLE;
    VkQueue graphicsQueue = VK_NULL_HANDLE;
    VkCommandPool commandPool = VK_NULL_HANDLE;
    // Other members such as the swapchain or the image view used as the
    // dynamic-rendering target would go here.
};
// Creates the instance, picks the first physical device, and builds the logical
// device (with the dynamic-rendering and shader-object features enabled), its
// graphics queue, and a command pool.
// Returns the populated VulkanContext; any failing Vulkan call goes through VK_CHECK.
VulkanContext initVulkan() {
    VulkanContext ctx = {};

    // ── Instance ──
    // Only instance-level extensions belong here. Dynamic rendering and shader
    // objects are DEVICE extensions and are enabled further below.
    const char* instanceExtensions[] = {
        "VK_KHR_surface",
        "VK_KHR_win32_surface", // Windows example; pick the matching surface extension on other platforms
    };

    VkApplicationInfo appInfo = {};
    appInfo.sType = VK_STRUCTURE_TYPE_APPLICATION_INFO;
    appInfo.pApplicationName = "ShaderObjectTriangle";
    appInfo.applicationVersion = VK_MAKE_VERSION(1, 0, 0);
    appInfo.pEngineName = "None";
    appInfo.engineVersion = VK_MAKE_VERSION(1, 0, 0);
    appInfo.apiVersion = VK_API_VERSION_1_3; // or Vulkan 1.2 plus the extensions

    VkInstanceCreateInfo instCreateInfo = {};
    instCreateInfo.sType = VK_STRUCTURE_TYPE_INSTANCE_CREATE_INFO;
    instCreateInfo.pApplicationInfo = &appInfo;
    instCreateInfo.enabledExtensionCount =
        static_cast<uint32_t>(sizeof(instanceExtensions) / sizeof(instanceExtensions[0]));
    instCreateInfo.ppEnabledExtensionNames = instanceExtensions;
    VK_CHECK(vkCreateInstance(&instCreateInfo, nullptr, &ctx.instance));

    // ── Physical device ──
    uint32_t gpuCount = 0;
    VK_CHECK(vkEnumeratePhysicalDevices(ctx.instance, &gpuCount, nullptr));
    VK_CHECK(gpuCount > 0 ? VK_SUCCESS : VK_ERROR_INITIALIZATION_FAILED);
    std::vector<VkPhysicalDevice> devices(gpuCount);
    VK_CHECK(vkEnumeratePhysicalDevices(ctx.instance, &gpuCount, devices.data()));
    ctx.physicalDevice = devices[0]; // simply take the first one

    // ── Graphics queue family ──
    // Search for a family that advertises VK_QUEUE_GRAPHICS_BIT instead of
    // assuming that family 0 supports graphics.
    uint32_t familyCount = 0;
    vkGetPhysicalDeviceQueueFamilyProperties(ctx.physicalDevice, &familyCount, nullptr);
    std::vector<VkQueueFamilyProperties> families(familyCount);
    vkGetPhysicalDeviceQueueFamilyProperties(ctx.physicalDevice, &familyCount, families.data());
    uint32_t graphicsFamily = familyCount; // sentinel: "not found"
    for (uint32_t i = 0; i < familyCount; ++i) {
        if ((families[i].queueFlags & VK_QUEUE_GRAPHICS_BIT) != 0) {
            graphicsFamily = i;
            break;
        }
    }
    VK_CHECK(graphicsFamily < familyCount ? VK_SUCCESS : VK_ERROR_INITIALIZATION_FAILED);

    // ── Logical device ──
    // Device extensions required by the example: dynamic rendering & shader object.
    const char* deviceExtensions[] = {
        "VK_KHR_dynamic_rendering",
        "VK_EXT_shader_object"
    };

    float queuePriority = 1.0f;
    VkDeviceQueueCreateInfo queueCreateInfo = {};
    queueCreateInfo.sType = VK_STRUCTURE_TYPE_DEVICE_QUEUE_CREATE_INFO;
    queueCreateInfo.queueFamilyIndex = graphicsFamily;
    queueCreateInfo.queueCount = 1;
    queueCreateInfo.pQueuePriorities = &queuePriority;

    // Enabling the extensions alone is not enough: the corresponding features
    // must also be switched on through the pNext chain, otherwise
    // vkCmdBeginRendering / vkCreateShadersEXT must not be used.
    // If the device lacks a feature, vkCreateDevice fails and VK_CHECK reports it.
    VkPhysicalDeviceDynamicRenderingFeatures dynamicRenderingFeatures = {};
    dynamicRenderingFeatures.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_DYNAMIC_RENDERING_FEATURES;
    dynamicRenderingFeatures.dynamicRendering = VK_TRUE;

    VkPhysicalDeviceShaderObjectFeaturesEXT shaderObjectFeatures = {};
    shaderObjectFeatures.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_OBJECT_FEATURES_EXT;
    shaderObjectFeatures.pNext = &dynamicRenderingFeatures;
    shaderObjectFeatures.shaderObject = VK_TRUE;

    VkPhysicalDeviceFeatures deviceFeatures = {}; // no optional core features requested
    VkDeviceCreateInfo devCreateInfo = {};
    devCreateInfo.sType = VK_STRUCTURE_TYPE_DEVICE_CREATE_INFO;
    devCreateInfo.pNext = &shaderObjectFeatures;
    devCreateInfo.queueCreateInfoCount = 1;
    devCreateInfo.pQueueCreateInfos = &queueCreateInfo;
    devCreateInfo.enabledExtensionCount =
        static_cast<uint32_t>(sizeof(deviceExtensions) / sizeof(deviceExtensions[0]));
    devCreateInfo.ppEnabledExtensionNames = deviceExtensions;
    devCreateInfo.pEnabledFeatures = &deviceFeatures;
    VK_CHECK(vkCreateDevice(ctx.physicalDevice, &devCreateInfo, nullptr, &ctx.device));

    vkGetDeviceQueue(ctx.device, graphicsFamily, 0, &ctx.graphicsQueue);

    // ── Command pool ──
    VkCommandPoolCreateInfo poolInfo = {};
    poolInfo.sType = VK_STRUCTURE_TYPE_COMMAND_POOL_CREATE_INFO;
    poolInfo.queueFamilyIndex = graphicsFamily;
    poolInfo.flags = VK_COMMAND_POOL_CREATE_RESET_COMMAND_BUFFER_BIT;
    VK_CHECK(vkCreateCommandPool(ctx.device, &poolInfo, nullptr, &ctx.commandPool));

    return ctx;
}

// 主函数
int main() {
VulkanContext ctx = initVulkan();

// 获取 VK_EXT_shader_object 的函数指针
auto vkCreateShadersEXT = reinterpret_cast<PFN_vkCreateShadersEXT>(
vkGetDeviceProcAddr(ctx.device, "vkCreateShadersEXT"));
auto vkCmdBindShadersEXT = reinterpret_cast<PFN_vkCmdBindShadersEXT>(
vkGetDeviceProcAddr(ctx.device, "vkCmdBindShadersEXT"));
auto vkDestroyShaderEXT = reinterpret_cast<PFN_vkDestroyShaderEXT>(
vkGetDeviceProcAddr(ctx.device, "vkDestroyShaderEXT"));
assert(vkCreateShadersEXT && vkCmdBindShadersEXT && vkDestroyShaderEXT);

// ── 创建 Shader Object ──
VkShaderEXT vertexShader = VK_NULL_HANDLE;
VkShaderEXT fragmentShader = VK_NULL_HANDLE;

// 配置顶点着色器
VkShaderCreateInfoEXT vertexShaderInfo = {};
vertexShaderInfo.sType = VK_STRUCTURE_TYPE_SHADER_CREATE_INFO_EXT;
vertexShaderInfo.stage = VK_SHADER_STAGE_VERTEX_BIT;
vertexShaderInfo.codeType = VK_SHADER_CODE_TYPE_SPIRV_EXT;
vertexShaderInfo.pCode = vertexShaderSPIRV.data();
vertexShaderInfo.codeSize = vertexShaderSPIRV.size() * sizeof(uint32_t);
vertexShaderInfo.pName = "main"; // entry point

VK_CHECK(vkCreateShadersEXT(ctx.device, 1, &vertexShaderInfo, nullptr, &vertexShader));

// 配置片元着色器
VkShaderCreateInfoEXT fragmentShaderInfo = {};
fragmentShaderInfo.sType = VK_STRUCTURE_TYPE_SHADER_CREATE_INFO_EXT;
fragmentShaderInfo.stage = VK_SHADER_STAGE_FRAGMENT_BIT;
fragmentShaderInfo.codeType = VK_SHADER_CODE_TYPE_SPIRV_EXT;
fragmentShaderInfo.pCode = fragmentShaderSPIRV.data();
fragmentShaderInfo.codeSize = fragmentShaderSPIRV.size() * sizeof(uint32_t);
fragmentShaderInfo.pName = "main"; // entry point

VK_CHECK(vkCreateShadersEXT(ctx.device, 1, &fragmentShaderInfo, nullptr, &fragmentShader));

// ── 创建 Command Buffer ──
VkCommandBufferAllocateInfo allocInfo = {};
allocInfo.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_ALLOCATE_INFO;
allocInfo.commandPool = ctx.commandPool;
allocInfo.level = VK_COMMAND_BUFFER_LEVEL_PRIMARY;
allocInfo.commandBufferCount = 1;
VkCommandBuffer cmdBuffer;
VK_CHECK(vkAllocateCommandBuffers(ctx.device, &allocInfo, &cmdBuffer));

// 开始录制命令
VkCommandBufferBeginInfo beginInfo = {};
beginInfo.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_BEGIN_INFO;
VK_CHECK(vkBeginCommandBuffer(cmdBuffer, &beginInfo));

// ── 动态渲染设置 ──
// 假定我们已经创建了一个目标 color image 和 image view(renderTargetView)
// 这里为示例使用伪变量
VkImageView renderTargetView = VK_NULL_HANDLE; // 实际代码中应为有效的 VkImageView
VkExtent2D renderExtent = {800, 600}; // 示例窗口尺寸

VkRenderingAttachmentInfo colorAttachment = {};
colorAttachment.sType = VK_STRUCTURE_TYPE_RENDERING_ATTACHMENT_INFO;
colorAttachment.imageView = renderTargetView;
colorAttachment.imageLayout = VK_IMAGE_LAYOUT_COLOR_ATTACHMENT_OPTIMAL;
colorAttachment.loadOp = VK_ATTACHMENT_LOAD_OP_CLEAR;
colorAttachment.storeOp = VK_ATTACHMENT_STORE_OP_STORE;
colorAttachment.clearValue.color.float32[0] = 0.0f;
colorAttachment.clearValue.color.float32[1] = 0.0f;
colorAttachment.clearValue.color.float32[2] = 0.0f;
colorAttachment.clearValue.color.float32[3] = 1.0f;

VkRenderingInfo renderingInfo = {};
renderingInfo.sType = VK_STRUCTURE_TYPE_RENDERING_INFO;
renderingInfo.renderArea.offset = {0, 0};
renderingInfo.renderArea.extent = renderExtent;
renderingInfo.layerCount = 1;
renderingInfo.colorAttachmentCount = 1;
renderingInfo.pColorAttachments = &colorAttachment;
// 本例不使用深度/模板,因此 pDepthAttachment 和 pStencilAttachment 为 nullptr

// 开始动态渲染
vkCmdBeginRendering(cmdBuffer, &renderingInfo);

// ── 绑定 Shader Objects ──
// 为顶点和片元阶段分别绑定 shader object
VkShaderEXT shaders[2] = { vertexShader, fragmentShader };
// stageFlags 指定需要绑定的 shader 阶段(本例为顶点和片元)
VkShaderStageFlags stageFlags = VK_SHADER_STAGE_VERTEX_BIT | VK_SHADER_STAGE_FRAGMENT_BIT;
// firstBinding 可设置为 0
vkCmdBindShadersEXT(cmdBuffer, stageFlags, 0, 2, shaders);

// ── 设置动态状态 ──
VkViewport viewport = {};
viewport.x = 0.0f;
viewport.y = 0.0f;
viewport.width = static_cast<float>(renderExtent.width);
viewport.height = static_cast<float>(renderExtent.height);
viewport.minDepth = 0.0f;
viewport.maxDepth = 1.0f;
vkCmdSetViewport(cmdBuffer, 0, 1, &viewport);

VkRect2D scissor = {};
scissor.offset = {0, 0};
scissor.extent = renderExtent;
vkCmdSetScissor(cmdBuffer, 0, 1, &scissor);

// 此外,你可能需要设置其它动态状态(例如混合、光栅化状态等),视你的渲染需要而定

// ── 绘制三角形 ──
// 此处未绑定 vertex buffers,假设 vertex shader 内部生成顶点数据(使用内置变量或其他技术)
// 最简单的情况直接发起 3 个顶点的 draw 调用
vkCmdDraw(cmdBuffer, 3, 1, 0, 0);

// 结束动态渲染
vkCmdEndRendering(cmdBuffer);

// 结束命令缓冲录制
VK_CHECK(vkEndCommandBuffer(cmdBuffer));

// ── 提交命令缓冲并等待执行完成 ──
VkSubmitInfo submitInfo = {};
submitInfo.sType = VK_STRUCTURE_TYPE_SUBMIT_INFO;
submitInfo.commandBufferCount = 1;
submitInfo.pCommandBuffers = &cmdBuffer;
VK_CHECK(vkQueueSubmit(ctx.graphicsQueue, 1, &submitInfo, VK_NULL_HANDLE));
VK_CHECK(vkQueueWaitIdle(ctx.graphicsQueue));

// ── 清理 Shader Objects 及其它资源 ──
vkDestroyShaderEXT(ctx.device, vertexShader, nullptr);
vkDestroyShaderEXT(ctx.device, fragmentShader, nullptr);

// 清理其它 Vulkan 对象(command pool、device、instance 等)
vkDestroyCommandPool(ctx.device, ctx.commandPool, nullptr);
vkDestroyDevice(ctx.device, nullptr);
vkDestroyInstance(ctx.instance, nullptr);

std::cout << "Shader Object Triangle rendered successfully.\n";
return 0;
}