@ -515,6 +515,16 @@ private:
void DeclareCommon ( ) {
void DeclareCommon ( ) {
thread_id =
thread_id =
DeclareInputBuiltIn ( spv : : BuiltIn : : SubgroupLocalInvocationId , t_in_uint , " thread_id " ) ;
DeclareInputBuiltIn ( spv : : BuiltIn : : SubgroupLocalInvocationId , t_in_uint , " thread_id " ) ;
thread_masks [ 0 ] =
DeclareInputBuiltIn ( spv : : BuiltIn : : SubgroupEqMask , t_in_uint4 , " thread_eq_mask " ) ;
thread_masks [ 1 ] =
DeclareInputBuiltIn ( spv : : BuiltIn : : SubgroupGeMask , t_in_uint4 , " thread_ge_mask " ) ;
thread_masks [ 2 ] =
DeclareInputBuiltIn ( spv : : BuiltIn : : SubgroupGtMask , t_in_uint4 , " thread_gt_mask " ) ;
thread_masks [ 3 ] =
DeclareInputBuiltIn ( spv : : BuiltIn : : SubgroupLeMask , t_in_uint4 , " thread_le_mask " ) ;
thread_masks [ 4 ] =
DeclareInputBuiltIn ( spv : : BuiltIn : : SubgroupLtMask , t_in_uint4 , " thread_lt_mask " ) ;
}
}
void DeclareVertex ( ) {
void DeclareVertex ( ) {
@ -2175,6 +2185,13 @@ private:
return { OpLoad ( t_uint , thread_id ) , Type : : Uint } ;
return { OpLoad ( t_uint , thread_id ) , Type : : Uint } ;
}
}
template < std : : size_t index >
Expression ThreadMask ( Operation ) {
// TODO(Rodrigo): Handle devices with different warp sizes
const Id mask = thread_masks [ index ] ;
return { OpLoad ( t_uint , AccessElement ( t_in_uint , mask , 0 ) ) , Type : : Uint } ;
}
Expression ShuffleIndexed ( Operation operation ) {
Expression ShuffleIndexed ( Operation operation ) {
const Id value = AsFloat ( Visit ( operation [ 0 ] ) ) ;
const Id value = AsFloat ( Visit ( operation [ 0 ] ) ) ;
const Id index = AsUint ( Visit ( operation [ 1 ] ) ) ;
const Id index = AsUint ( Visit ( operation [ 1 ] ) ) ;
@ -2639,6 +2656,11 @@ private:
& SPIRVDecompiler : : Vote < & Module : : OpSubgroupAllEqualKHR > ,
& SPIRVDecompiler : : Vote < & Module : : OpSubgroupAllEqualKHR > ,
& SPIRVDecompiler : : ThreadId ,
& SPIRVDecompiler : : ThreadId ,
& SPIRVDecompiler : : ThreadMask < 0 > , // Eq
& SPIRVDecompiler : : ThreadMask < 1 > , // Ge
& SPIRVDecompiler : : ThreadMask < 2 > , // Gt
& SPIRVDecompiler : : ThreadMask < 3 > , // Le
& SPIRVDecompiler : : ThreadMask < 4 > , // Lt
& SPIRVDecompiler : : ShuffleIndexed ,
& SPIRVDecompiler : : ShuffleIndexed ,
& SPIRVDecompiler : : MemoryBarrierGL ,
& SPIRVDecompiler : : MemoryBarrierGL ,
@ -2763,6 +2785,7 @@ private:
Id workgroup_id { } ;
Id workgroup_id { } ;
Id local_invocation_id { } ;
Id local_invocation_id { } ;
Id thread_id { } ;
Id thread_id { } ;
std : : array < Id , 5 > thread_masks { } ; // eq, ge, gt, le, lt
VertexIndices in_indices ;
VertexIndices in_indices ;
VertexIndices out_indices ;
VertexIndices out_indices ;