Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
pablode committed Jan 18, 2025
1 parent 19cd2de commit a6b2371
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 26 deletions.
78 changes: 58 additions & 20 deletions src/cgpu/impl/Cgpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1218,13 +1218,16 @@ namespace gtl
return true;
}

std::mutex s_setMutex;

static bool cgpuCreatePipelineDescriptorSet(CgpuIDevice* idevice,
CgpuIShader* ishader,
VkShaderStageFlags stageFlags,
VkDescriptorPool& descriptorPool,
VkDescriptorSetLayout& descriptorSetLayout,
std::vector<VkDescriptorSetLayoutBinding>& descriptorSetLayoutBindings)
{
s_setMutex.lock();
const CgpuShaderReflection* shaderReflection = &ishader->reflection;
size_t bindingCount = shaderReflection->bindings.size();

Expand All @@ -1233,12 +1236,12 @@ namespace gtl

for (uint32_t i = 0; i < bindingCount; i++)
{
const CgpuShaderReflectionBinding* bindingReflection = &shaderReflection->bindings[i];
const CgpuShaderReflectionBinding& bindingReflection = shaderReflection->bindings[i];

VkDescriptorSetLayoutBinding layoutBinding = {
.binding = bindingReflection->binding,
.descriptorType = (VkDescriptorType) bindingReflection->descriptorType,
.descriptorCount = bindingReflection->count,
.binding = bindingReflection.binding,
.descriptorType = (VkDescriptorType) bindingReflection.descriptorType,
.descriptorCount = bindingReflection.count,
.stageFlags = stageFlags,
.pImmutableSamplers = nullptr,
};
Expand Down Expand Up @@ -1268,6 +1271,7 @@ namespace gtl
nullptr,
&descriptorSetLayout
);
s_setMutex.unlock();

if (result != VK_SUCCESS) {
CGPU_RETURN_ERROR("failed to create descriptor set layout");
Expand Down Expand Up @@ -1353,12 +1357,12 @@ namespace gtl
return true;
}

std::mutex s_shaderMutex;

static bool cgpuCreateShader(CgpuIDevice* idevice,
CgpuShaderCreateInfo createInfo,
CgpuShader shader)
const CgpuShaderCreateInfo& createInfo,
CgpuIShader* ishader)
{
CGPU_RESOLVE_SHADER({ shader }, ishader);

VkShaderModuleCreateInfo shaderModuleCreateInfo = {
.sType = VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO,
.pNext = nullptr,
Expand All @@ -1367,12 +1371,14 @@ namespace gtl
.pCode = (uint32_t*) createInfo.source,
};

s_shaderMutex.lock();
VkResult result = idevice->table.vkCreateShaderModule(
idevice->logicalDevice,
&shaderModuleCreateInfo,
nullptr,
&ishader->module
);
s_shaderMutex.unlock();
if (result != VK_SUCCESS) {
CGPU_RETURN_ERROR("failed to create shader module");
}
Expand All @@ -1390,7 +1396,7 @@ namespace gtl
CGPU_RETURN_ERROR("failed to reflect shader");
}

if (createInfo.stageFlags != CGPU_SHADER_STAGE_FLAG_COMPUTE)
if (createInfo.stageFlags == CGPU_SHADER_STAGE_FLAG_CLOSEST_HIT || createInfo.stageFlags == CGPU_SHADER_STAGE_FLAG_ANY_HIT)
{
if (!cgpuCreateRtPipelineLibrary(idevice, ishader, CGPU_RT_PIPELINE_ACCESS_FLAGS))
{
Expand All @@ -1410,7 +1416,9 @@ namespace gtl

shader->handle = iinstance->ishaderStore.allocate();

if (!cgpuCreateShader(idevice, createInfo, *shader))
CGPU_RESOLVE_SHADER(*shader, ishader);

if (!cgpuCreateShader(idevice, createInfo, ishader))
{
iinstance->ishaderStore.free(shader->handle);
return false;
Expand All @@ -1431,12 +1439,22 @@ namespace gtl
shaders[i].handle = iinstance->ishaderStore.allocate();
}

std::vector<CgpuIShader*> ishaders;
ishaders.resize(shaderCount, nullptr);

for (uint32_t j = 0; j < 16; j++) // TODO: IMPORTANT design flaw..
for (uint32_t i = 0; i < shaderCount; i++)
{
CGPU_RESOLVE_SHADER(shaders[i], ishader);
ishaders[i] = ishader;
}

std::atomic<int> errorCount = false;

#pragma omp parallel for schedule(dynamic, 1)
for (int i = 0; i < int(shaderCount); i++)
{
if (!cgpuCreateShader(idevice, createInfos[i], shaders[i]))
if (!cgpuCreateShader(idevice, createInfos[i], ishaders[i]))
{
errorCount++;
}
Expand Down Expand Up @@ -2156,6 +2174,33 @@ namespace gtl

memset(ipipeline, 0, sizeof(CgpuIPipeline));

// Gather stages
std::vector<VkPipelineShaderStageCreateInfo> stages;
stages.reserve(1 + createInfo.missShaderCount);

auto pushStage = [&stages](VkShaderStageFlagBits stage, VkShaderModule module) {
VkPipelineShaderStageCreateInfo stageCreateInfo = {
.sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO,
.pNext = nullptr,
.flags = 0,
.stage = stage,
.module = module,
.pName = CGPU_SHADER_ENTRY_POINT,
.pSpecializationInfo = nullptr,
};
stages.push_back(stageCreateInfo);
};

CGPU_RESOLVE_SHADER(createInfo.rgenShader, irgenShader);
pushStage(VK_SHADER_STAGE_RAYGEN_BIT_KHR, irgenShader->module);

for (uint32_t i = 0; i < createInfo.missShaderCount; i++)
{
CGPU_RESOLVE_SHADER(createInfo.missShaders[i], imissShader);
assert(imissShader->module);
pushStage(VK_SHADER_STAGE_MISS_BIT_KHR, imissShader->module);
}

// Gather groups
size_t groupCount = 1/*rgen*/ + createInfo.missShaderCount + createInfo.hitGroupCount;
std::vector<VkRayTracingShaderGroupCreateInfoKHR> groups(groupCount);
Expand Down Expand Up @@ -2209,8 +2254,6 @@ namespace gtl
}

// Create descriptor and pipeline layout.
CGPU_RESOLVE_SHADER(createInfo.rgenShader, irgenShader);

if (!cgpuCreatePipelineDescriptors(idevice, irgenShader, CGPU_RT_PIPELINE_ACCESS_FLAGS, ipipeline))
{
goto cleanup_fail;
Expand All @@ -2230,11 +2273,6 @@ namespace gtl
std::vector<VkPipeline> libraries;
libraries.reserve(groupCount);

libraries.push_back(irgenShader->pipelineLibrary.pipeline);
for (uint32_t i = 0; i < createInfo.missShaderCount; i++)
{
libraries.push_back(getShaderPipelineHandle(createInfo.missShaders[i]));
}
for (uint32_t i = 0; i < createInfo.hitGroupCount; i++)
{
CgpuShader closestHitShader = createInfo.hitGroups[i].closestHitShader;
Expand Down Expand Up @@ -2281,8 +2319,8 @@ namespace gtl
.sType = VK_STRUCTURE_TYPE_RAY_TRACING_PIPELINE_CREATE_INFO_KHR,
.pNext = nullptr,
.flags = flags,
.stageCount = 0,
.pStages = nullptr,
.stageCount = (uint32_t) stages.size(),
.pStages = stages.data(),
.groupCount = (uint32_t) groups.size(),
.pGroups = groups.data(),
.maxPipelineRayRecursionDepth = 1,
Expand Down
20 changes: 14 additions & 6 deletions src/cgpu/impl/ShaderReflection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,14 @@ namespace gtl
}
}

fprintf(stderr, "maxRayPayloadSize: %u\n", reflection->maxRayPayloadSize);
fprintf(stderr, "maxRayHitAttributeSize: %u\n", reflection->maxRayHitAttributeSize);

// TODO: closest and any hit shaders have different values -> crash on NV..
reflection->maxRayPayloadSize = 80;
reflection->maxRayHitAttributeSize = 8;


if (spvReflectEnumerateDescriptorBindings(&shaderModule, &bindingCount, nullptr) != SPV_REFLECT_RESULT_SUCCESS)
{
goto fail;
Expand All @@ -154,15 +162,15 @@ namespace gtl
{
const SpvReflectDescriptorBinding* srcBinding = bindings[i];

CgpuShaderReflectionBinding* dstBinding = &reflection->bindings[i];
dstBinding->binding = srcBinding->binding;
dstBinding->count = srcBinding->count;
dstBinding->descriptorType = (int)srcBinding->descriptor_type;
CgpuShaderReflectionBinding& dstBinding = reflection->bindings[i];
dstBinding.binding = srcBinding->binding;
dstBinding.count = srcBinding->count;
dstBinding.descriptorType = (int)srcBinding->descriptor_type;
// Unfortunately SPIRV-Reflect lacks the functionality to detect read accesses:
// https://github.com/KhronosGroup/SPIRV-Reflect/issues/99
dstBinding->readAccess = srcBinding->accessed;
dstBinding.readAccess = srcBinding->accessed;
const SpvReflectTypeDescription* typeDescription = srcBinding->type_description;
dstBinding->writeAccess = srcBinding->accessed && ~(typeDescription->decoration_flags & SPV_REFLECT_DECORATION_NON_WRITABLE);
dstBinding.writeAccess = srcBinding->accessed && ~(typeDescription->decoration_flags & SPV_REFLECT_DECORATION_NON_WRITABLE);
}
}

Expand Down
5 changes: 5 additions & 0 deletions src/gi/impl/Gi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1762,6 +1762,11 @@ namespace gtl
scene->dirtyFlags |= GiSceneDirtyFlags::DirtyFramebuffer | GiSceneDirtyFlags::DirtyBvh; // SBT
}

if (!scene->shaderCache)
{
return GiStatus::Error;
}

if (!scene->bvh || bool(scene->dirtyFlags & GiSceneDirtyFlags::DirtyBvh))
{
if (scene->bvh) _giDestroyBvh(scene->bvh);
Expand Down

0 comments on commit a6b2371

Please sign in to comment.