diff --git a/src/video_core/shader/control_flow.cpp b/src/video_core/shader/control_flow.cpp index d47c63d9fb..b427ac8735 100644 --- a/src/video_core/shader/control_flow.cpp +++ b/src/video_core/shader/control_flow.cpp @@ -16,7 +16,9 @@ #include "video_core/shader/shader_ir.h" namespace VideoCommon::Shader { + namespace { + using Tegra::Shader::Instruction; using Tegra::Shader::OpCode; @@ -68,15 +70,15 @@ struct CFGRebuildState { const ProgramCode& program_code; ConstBufferLocker& locker; u32 start{}; - std::vector block_info{}; - std::list inspect_queries{}; - std::list queries{}; - std::unordered_map registered{}; - std::set labels{}; - std::map ssy_labels{}; - std::map pbk_labels{}; - std::unordered_map stacks{}; - ASTManager* manager; + std::vector block_info; + std::list inspect_queries; + std::list queries; + std::unordered_map registered; + std::set labels; + std::map ssy_labels; + std::map pbk_labels; + std::unordered_map stacks; + ASTManager* manager{}; }; enum class BlockCollision : u32 { None, Found, Inside }; @@ -109,7 +111,7 @@ BlockInfo& CreateBlockInfo(CFGRebuildState& state, u32 start, u32 end) { } Pred GetPredicate(u32 index, bool negated) { - return static_cast(index + (negated ? 8 : 0)); + return static_cast(static_cast(index) + (negated ? 8ULL : 0ULL)); } /** @@ -136,15 +138,13 @@ struct BranchIndirectInfo { s32 relative_position{}; }; -std::optional TrackBranchIndirectInfo(const CFGRebuildState& state, - u32 start_address, u32 current_position) { - const u32 shader_start = state.start; - u32 pos = current_position; - BranchIndirectInfo result{}; - u64 track_register = 0; +struct BufferInfo { + u32 index; + u32 offset; +}; - // Step 0 Get BRX Info - const Instruction instr = {state.program_code[pos]}; +std::optional> GetBRXInfo(const CFGRebuildState& state, u32& pos) { + const Instruction instr = state.program_code[pos]; const auto opcode = OpCode::Decode(instr); if (opcode->get().GetId() != OpCode::Id::BRX) { return std::nullopt; @@ -152,86 +152,94 @@ std::optional TrackBranchIndirectInfo(const CFGRebuildState& if (instr.brx.constant_buffer != 0) { return std::nullopt; } - track_register = instr.gpr8.Value(); - result.relative_position = instr.brx.GetBranchExtend(); - pos--; - bool found_track = false; + --pos; + return std::make_pair(instr.brx.GetBranchExtend(), instr.gpr8.Value()); +} - // Step 1 Track LDC - while (pos >= shader_start) { - if (IsSchedInstruction(pos, shader_start)) { - pos--; - continue; - } - const Instruction instr = {state.program_code[pos]}; - const auto opcode = OpCode::Decode(instr); - if (opcode->get().GetId() == OpCode::Id::LD_C) { - if (instr.gpr0.Value() == track_register && - instr.ld_c.type.Value() == Tegra::Shader::UniformType::Single) { - result.buffer = instr.cbuf36.index.Value(); - result.offset = static_cast(instr.cbuf36.GetOffset()); - track_register = instr.gpr8.Value(); - pos--; - found_track = true; - break; - } - } - pos--; - } - - if (!found_track) { - return std::nullopt; - } - found_track = false; - - // Step 2 Track SHL - while (pos >= shader_start) { - if (IsSchedInstruction(pos, shader_start)) { - pos--; +template +// requires std::predicate +// requires std::invocable +std::optional TrackInstruction(const CFGRebuildState& state, u32& pos, TestCallable test, + PackCallable pack) { + for (; pos >= state.start; --pos) { + if (IsSchedInstruction(pos, state.start)) { continue; } const Instruction instr = state.program_code[pos]; const auto opcode = OpCode::Decode(instr); - if (opcode->get().GetId() == OpCode::Id::SHL_IMM) { - if (instr.gpr0.Value() == track_register) { - track_register = instr.gpr8.Value(); - pos--; - found_track = true; - break; - } - } - pos--; - } - - if (!found_track) { - return std::nullopt; - } - found_track = false; - - // Step 3 Track IMNMX - while (pos >= shader_start) { - if (IsSchedInstruction(pos, shader_start)) { - pos--; + if (!opcode) { continue; } - const Instruction instr = state.program_code[pos]; - const auto opcode = OpCode::Decode(instr); - if (opcode->get().GetId() == OpCode::Id::IMNMX_IMM) { - if (instr.gpr0.Value() == track_register) { - track_register = instr.gpr8.Value(); - result.entries = instr.alu.GetSignedImm20_20() + 1; - pos--; - found_track = true; - break; - } + if (test(instr, opcode->get())) { + --pos; + return std::make_optional(pack(instr, opcode->get())); } - pos--; } + return std::nullopt; +} - if (!found_track) { +std::optional> TrackLDC(const CFGRebuildState& state, u32& pos, + u64 brx_tracked_register) { + return TrackInstruction>( + state, pos, + [brx_tracked_register](auto instr, const auto& opcode) { + return opcode.GetId() == OpCode::Id::LD_C && + instr.gpr0.Value() == brx_tracked_register && + instr.ld_c.type.Value() == Tegra::Shader::UniformType::Single; + }, + [](auto instr, const auto& opcode) { + const BufferInfo info = {static_cast(instr.cbuf36.index.Value()), + static_cast(instr.cbuf36.GetOffset())}; + return std::make_pair(info, instr.gpr8.Value()); + }); +} + +std::optional TrackSHLRegister(const CFGRebuildState& state, u32& pos, + u64 ldc_tracked_register) { + return TrackInstruction(state, pos, + [ldc_tracked_register](auto instr, const auto& opcode) { + return opcode.GetId() == OpCode::Id::SHL_IMM && + instr.gpr0.Value() == ldc_tracked_register; + }, + [](auto instr, const auto&) { return instr.gpr8.Value(); }); +} + +std::optional TrackIMNMXValue(const CFGRebuildState& state, u32& pos, + u64 shl_tracked_register) { + return TrackInstruction(state, pos, + [shl_tracked_register](auto instr, const auto& opcode) { + return opcode.GetId() == OpCode::Id::IMNMX_IMM && + instr.gpr0.Value() == shl_tracked_register; + }, + [](auto instr, const auto&) { + return static_cast(instr.alu.GetSignedImm20_20() + 1); + }); +} + +std::optional TrackBranchIndirectInfo(const CFGRebuildState& state, u32 pos) { + const auto brx_info = GetBRXInfo(state, pos); + if (!brx_info) { return std::nullopt; } - return result; + const auto [relative_position, brx_tracked_register] = *brx_info; + + const auto ldc_info = TrackLDC(state, pos, brx_tracked_register); + if (!ldc_info) { + return std::nullopt; + } + const auto [buffer_info, ldc_tracked_register] = *ldc_info; + + const auto shl_tracked_register = TrackSHLRegister(state, pos, ldc_tracked_register); + if (!shl_tracked_register) { + return std::nullopt; + } + + const auto entries = TrackIMNMXValue(state, pos, *shl_tracked_register); + if (!entries) { + return std::nullopt; + } + + return BranchIndirectInfo{buffer_info.index, buffer_info.offset, *entries, relative_position}; } std::pair ParseCode(CFGRebuildState& state, u32 address) { @@ -420,30 +428,30 @@ std::pair ParseCode(CFGRebuildState& state, u32 address) break; } case OpCode::Id::BRX: { - auto tmp = TrackBranchIndirectInfo(state, address, offset); - if (tmp) { - auto result = *tmp; - std::vector branches{}; - s32 pc_target = offset + result.relative_position; - for (u32 i = 0; i < result.entries; i++) { - auto k = state.locker.ObtainKey(result.buffer, result.offset + i * 4); - if (!k) { - return {ParseResult::AbnormalFlow, parse_info}; - } - u32 value = *k; - u32 target = static_cast((value >> 3) + pc_target); - insert_label(state, target); - branches.emplace_back(value, target); - } - parse_info.end_address = offset; - parse_info.branch_info = MakeBranchInfo( - static_cast(instr.gpr8.Value()), std::move(branches)); - - return {ParseResult::ControlCaught, parse_info}; - } else { + const auto tmp = TrackBranchIndirectInfo(state, offset); + if (!tmp) { LOG_WARNING(HW_GPU, "BRX Track Unsuccesful"); + return {ParseResult::AbnormalFlow, parse_info}; } - return {ParseResult::AbnormalFlow, parse_info}; + + const auto result = *tmp; + const s32 pc_target = offset + result.relative_position; + std::vector branches; + for (u32 i = 0; i < result.entries; i++) { + auto key = state.locker.ObtainKey(result.buffer, result.offset + i * 4); + if (!key) { + return {ParseResult::AbnormalFlow, parse_info}; + } + u32 value = *key; + u32 target = static_cast((value >> 3) + pc_target); + insert_label(state, target); + branches.emplace_back(value, target); + } + parse_info.end_address = offset; + parse_info.branch_info = MakeBranchInfo( + static_cast(instr.gpr8.Value()), std::move(branches)); + + return {ParseResult::ControlCaught, parse_info}; } default: break;