Skip to content

Commit

Permalink
specialization: Fix uniform buffer size check.
Browse files Browse the repository at this point in the history
  • Loading branch information
squidbus committed Jan 11, 2025
1 parent e09cdd9 commit 8951ac9
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 13 deletions.
38 changes: 30 additions & 8 deletions src/shader_recompiler/specialization.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@ struct VsAttribSpecialization {
AmdGpu::NumberClass num_class{};

auto operator<=>(const VsAttribSpecialization&) const = default;

[[nodiscard]] bool IsCompatible(const VsAttribSpecialization& other) const {
return *this == other;
}
};

struct BufferSpecialization {
Expand All @@ -26,7 +30,7 @@ struct BufferSpecialization {
u8 element_size : 2 = 0;
u32 size = 0;

bool operator==(const BufferSpecialization& other) const {
[[nodiscard]] bool IsCompatible(const BufferSpecialization& other) const {
return stride == other.stride && is_storage == other.is_storage &&
swizzle_enable == other.swizzle_enable &&
(!swizzle_enable ||
Expand All @@ -41,6 +45,10 @@ struct TextureBufferSpecialization {
AmdGpu::NumberConversion num_conversion{};

auto operator<=>(const TextureBufferSpecialization&) const = default;

[[nodiscard]] bool IsCompatible(const TextureBufferSpecialization& other) const {
return *this == other;
}
};

struct ImageSpecialization {
Expand All @@ -51,19 +59,31 @@ struct ImageSpecialization {
AmdGpu::NumberConversion num_conversion{};

auto operator<=>(const ImageSpecialization&) const = default;

[[nodiscard]] bool IsCompatible(const ImageSpecialization& other) const {
return *this == other;
}
};

/// Width/height pair compared as one unit of a specialization.
/// NOTE(review): presumably the element type of StageSpecialization::fmasks, whose entries are
/// compared through IsCompatible — confirm against the full header.
struct FMaskSpecialization {
// Neither member has a default member initializer.
u32 width;
u32 height;

// Defaulted member-wise three-way comparison; declaring it also implicitly declares the
// defaulted operator== that IsCompatible relies on.
auto operator<=>(const FMaskSpecialization&) const = default;

/// Returns true only when `other` has the same width and height as this object
/// (plain equality, so the check is symmetric for this type).
[[nodiscard]] bool IsCompatible(const FMaskSpecialization& other) const {
return *this == other;
}
};

/// Single-flag specialization record for a sampler.
/// NOTE(review): presumably the element type of StageSpecialization::samplers — confirm against
/// the full header.
struct SamplerSpecialization {
// Defaults to false when not explicitly set.
bool force_unnormalized = false;

// Defaulted member-wise three-way comparison; declaring it also implicitly declares the
// defaulted operator== that IsCompatible relies on.
auto operator<=>(const SamplerSpecialization&) const = default;

/// Returns true only when `other` has the same force_unnormalized value as this object
/// (plain equality, so the check is symmetric for this type).
[[nodiscard]] bool IsCompatible(const SamplerSpecialization& other) const {
return *this == other;
}
};

/**
Expand Down Expand Up @@ -179,7 +199,9 @@ struct StageSpecialization {
}
}

bool operator==(const StageSpecialization& other) const {
/// Checks whether the permutation described by this specialization can be used in place of 'other'.
/// Note that this check is not symmetric: a.IsCompatible(b) does not imply b.IsCompatible(a).
[[nodiscard]] bool IsCompatible(const StageSpecialization& other) const {
if (start != other.start) {
return false;
}
Expand All @@ -190,7 +212,7 @@ struct StageSpecialization {
return false;
}
for (u32 i = 0; i < vs_attribs.size(); i++) {
if (vs_attribs[i] != other.vs_attribs[i]) {
if (!vs_attribs[i].IsCompatible(other.vs_attribs[i])) {
return false;
}
}
Expand All @@ -202,27 +224,27 @@ struct StageSpecialization {
binding++;
}
for (u32 i = 0; i < buffers.size(); i++) {
if (other.bitset[binding++] && buffers[i] != other.buffers[i]) {
if (other.bitset[binding++] && !buffers[i].IsCompatible(other.buffers[i])) {
return false;
}
}
for (u32 i = 0; i < tex_buffers.size(); i++) {
if (other.bitset[binding++] && tex_buffers[i] != other.tex_buffers[i]) {
if (other.bitset[binding++] && !tex_buffers[i].IsCompatible(other.tex_buffers[i])) {
return false;
}
}
for (u32 i = 0; i < images.size(); i++) {
if (other.bitset[binding++] && images[i] != other.images[i]) {
if (other.bitset[binding++] && !images[i].IsCompatible(other.images[i])) {
return false;
}
}
for (u32 i = 0; i < fmasks.size(); i++) {
if (other.bitset[binding++] && fmasks[i] != other.fmasks[i]) {
if (other.bitset[binding++] && !fmasks[i].IsCompatible(other.fmasks[i])) {
return false;
}
}
for (u32 i = 0; i < samplers.size(); i++) {
if (samplers[i] != other.samplers[i]) {
if (!samplers[i].IsCompatible(other.samplers[i])) {
return false;
}
}
Expand Down
11 changes: 6 additions & 5 deletions src/video_core/renderer_vulkan/vk_pipeline_cache.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -517,20 +517,22 @@ PipelineCache::Result PipelineCache::GetProgram(Stage stage, LogicalStage l_stag
auto start = binding;
const auto module = CompileModule(program->info, runtime_info, params.code, 0, binding);
const auto spec = Shader::StageSpecialization(program->info, runtime_info, profile, start);
const auto fetch_shader = spec.fetch_shader_data;
program->AddPermut(module, std::move(spec));
return std::make_tuple(&program->info, module, spec.fetch_shader_data,
HashCombine(params.hash, 0));
return std::make_tuple(&program->info, module, fetch_shader, HashCombine(params.hash, 0));
}
it_pgm.value()->info.user_data = params.user_data;

auto& program = it_pgm.value();
auto& info = program->info;
info.RefreshFlatBuf();
const auto spec = Shader::StageSpecialization(info, runtime_info, profile, binding);
const auto fetch_shader = spec.fetch_shader_data;
size_t perm_idx = program->modules.size();
vk::ShaderModule module{};

const auto it = std::ranges::find(program->modules, spec, &Program::Module::spec);
const auto it = std::ranges::find_if(
program->modules, [&spec](const auto& module) { return module.spec.IsCompatible(spec); });
if (it == program->modules.end()) {
auto new_info = Shader::Info(stage, l_stage, params);
module = CompileModule(new_info, runtime_info, params.code, perm_idx, binding);
Expand All @@ -540,8 +542,7 @@ PipelineCache::Result PipelineCache::GetProgram(Stage stage, LogicalStage l_stag
module = it->module;
perm_idx = std::distance(program->modules.begin(), it);
}
return std::make_tuple(&info, module, spec.fetch_shader_data,
HashCombine(params.hash, perm_idx));
return std::make_tuple(&info, module, fetch_shader, HashCombine(params.hash, perm_idx));
}

std::optional<vk::ShaderModule> PipelineCache::ReplaceShader(vk::ShaderModule module,
Expand Down

0 comments on commit 8951ac9

Please sign in to comment.