@ -1123,15 +1123,7 @@ private:
}
if ( const auto gmem = std : : get_if < GmemNode > ( & * node ) ) {
const Id gmem_buffer = global_buffers . at ( gmem - > GetDescriptor ( ) ) ;
const Id real = AsUint ( Visit ( gmem - > GetRealAddress ( ) ) ) ;
const Id base = AsUint ( Visit ( gmem - > GetBaseAddress ( ) ) ) ;
Id offset = OpISub ( t_uint , real , base ) ;
offset = OpUDiv ( t_uint , offset , Constant ( t_uint , 4U ) ) ;
return { OpLoad ( t_float ,
OpAccessChain ( t_gmem_float , gmem_buffer , Constant ( t_uint , 0U ) , offset ) ) ,
Type : : Float } ;
return { OpLoad ( t_uint , GetGlobalMemoryPointer ( * gmem ) ) , Type : : Uint } ;
}
if ( const auto lmem = std : : get_if < LmemNode > ( & * node ) ) {
@ -1142,10 +1134,7 @@ private:
}
if ( const auto smem = std : : get_if < SmemNode > ( & * node ) ) {
Id address = AsUint ( Visit ( smem - > GetAddress ( ) ) ) ;
address = OpShiftRightLogical ( t_uint , address , Constant ( t_uint , 2U ) ) ;
const Id pointer = OpAccessChain ( t_smem_uint , shared_memory , address ) ;
return { OpLoad ( t_uint , pointer ) , Type : : Uint } ;
return { OpLoad ( t_uint , GetSharedMemoryPointer ( * smem ) ) , Type : : Uint } ;
}
if ( const auto internal_flag = std : : get_if < InternalFlagNode > ( & * node ) ) {
@ -1339,20 +1328,10 @@ private:
target = { OpAccessChain ( t_prv_float , local_memory , address ) , Type : : Float } ;
} else if ( const auto smem = std : : get_if < SmemNode > ( & * dest ) ) {
ASSERT ( stage = = ShaderType : : Compute ) ;
Id address = AsUint ( Visit ( smem - > GetAddress ( ) ) ) ;
address = OpShiftRightLogical ( t_uint , address , Constant ( t_uint , 2U ) ) ;
target = { OpAccessChain ( t_smem_uint , shared_memory , address ) , Type : : Uint } ;
target = { GetSharedMemoryPointer ( * smem ) , Type : : Uint } ;
} else if ( const auto gmem = std : : get_if < GmemNode > ( & * dest ) ) {
const Id real = AsUint ( Visit ( gmem - > GetRealAddress ( ) ) ) ;
const Id base = AsUint ( Visit ( gmem - > GetBaseAddress ( ) ) ) ;
const Id diff = OpISub ( t_uint , real , base ) ;
const Id offset = OpShiftRightLogical ( t_uint , diff , Constant ( t_uint , 2 ) ) ;
const Id gmem_buffer = global_buffers . at ( gmem - > GetDescriptor ( ) ) ;
target = { OpAccessChain ( t_gmem_float , gmem_buffer , Constant ( t_uint , 0 ) , offset ) ,
Type : : Float } ;
target = { GetGlobalMemoryPointer ( * gmem ) , Type : : Uint } ;
} else {
UNIMPLEMENTED ( ) ;
@ -1804,11 +1783,16 @@ private:
return { } ;
}
Expression UAtomicAdd ( Operation operation ) {
const auto & smem = std : : get < SmemNode > ( * operation [ 0 ] ) ;
Id address = AsUint ( Visit ( smem . GetAddress ( ) ) ) ;
address = OpShiftRightLogical ( t_uint , address , Constant ( t_uint , 2U ) ) ;
const Id pointer = OpAccessChain ( t_smem_uint , shared_memory , address ) ;
Expression AtomicAdd ( Operation operation ) {
Id pointer ;
if ( const auto smem = std : : get_if < SmemNode > ( & * operation [ 0 ] ) ) {
pointer = GetSharedMemoryPointer ( * smem ) ;
} else if ( const auto gmem = std : : get_if < GmemNode > ( & * operation [ 0 ] ) ) {
pointer = GetGlobalMemoryPointer ( * gmem ) ;
} else {
UNREACHABLE ( ) ;
return { Constant ( t_uint , 0 ) , Type : : Uint } ;
}
const Id scope = Constant ( t_uint , static_cast < u32 > ( spv : : Scope : : Device ) ) ;
const Id semantics = Constant ( t_uint , 0U ) ;
@ -2243,6 +2227,22 @@ private:
return { } ;
}
Id GetGlobalMemoryPointer ( const GmemNode & gmem ) {
const Id real = AsUint ( Visit ( gmem . GetRealAddress ( ) ) ) ;
const Id base = AsUint ( Visit ( gmem . GetBaseAddress ( ) ) ) ;
const Id diff = OpISub ( t_uint , real , base ) ;
const Id offset = OpShiftRightLogical ( t_uint , diff , Constant ( t_uint , 2 ) ) ;
const Id buffer = global_buffers . at ( gmem . GetDescriptor ( ) ) ;
return OpAccessChain ( t_gmem_uint , buffer , Constant ( t_uint , 0 ) , offset ) ;
}
Id GetSharedMemoryPointer ( const SmemNode & smem ) {
ASSERT ( stage = = ShaderType : : Compute ) ;
Id address = AsUint ( Visit ( smem . GetAddress ( ) ) ) ;
address = OpShiftRightLogical ( t_uint , address , Constant ( t_uint , 2U ) ) ;
return OpAccessChain ( t_smem_uint , shared_memory , address ) ;
}
static constexpr std : : array operation_decompilers = {
& SPIRVDecompiler : : Assign ,
@ -2389,7 +2389,7 @@ private:
& SPIRVDecompiler : : AtomicImageXor ,
& SPIRVDecompiler : : AtomicImageExchange ,
& SPIRVDecompiler : : U AtomicAdd,
& SPIRVDecompiler : : AtomicAdd,
& SPIRVDecompiler : : Branch ,
& SPIRVDecompiler : : BranchIndirect ,
@ -2485,9 +2485,9 @@ private:
Id t_smem_uint { } ;
const Id t_gmem_ floa t = TypePointer ( spv : : StorageClass : : StorageBuffer , t_ floa t) ;
const Id t_gmem_ uin t = TypePointer ( spv : : StorageClass : : StorageBuffer , t_ uin t) ;
const Id t_gmem_array =
Name ( Decorate ( TypeRuntimeArray ( t_ floa t) , spv : : Decoration : : ArrayStride , 4U ) , " GmemArray " ) ;
Name ( Decorate ( TypeRuntimeArray ( t_ uin t) , spv : : Decoration : : ArrayStride , 4U ) , " GmemArray " ) ;
const Id t_gmem_struct = MemberDecorate (
Decorate ( TypeStruct ( t_gmem_array ) , spv : : Decoration : : Block ) , 0 , spv : : Decoration : : Offset , 0 ) ;
const Id t_gmem_ssbo = TypePointer ( spv : : StorageClass : : StorageBuffer , t_gmem_struct ) ;