@ -3,6 +3,7 @@
// Refer to the license.txt file included.
// Refer to the license.txt file included.
# include <algorithm>
# include <algorithm>
# include <functional>
# include <tuple>
# include <tuple>
# include <type_traits>
# include <type_traits>
@ -88,6 +89,26 @@ bool FoldWhenAllImmediates(IR::Inst& inst, Func&& func) {
return true ;
return true ;
}
}
/// Return true when all values in a range are equal
template < typename Range >
bool AreEqual ( const Range & range ) {
auto resolver { [ ] ( const auto & value ) { return value . Resolve ( ) ; } } ;
auto equal { [ ] ( const IR : : Value & lhs , const IR : : Value & rhs ) {
if ( lhs = = rhs ) {
return true ;
}
// Not equal, but try to match if they read the same constant buffer
if ( ! lhs . IsImmediate ( ) & & ! rhs . IsImmediate ( ) & &
lhs . Inst ( ) - > GetOpcode ( ) = = IR : : Opcode : : GetCbufU32 & &
rhs . Inst ( ) - > GetOpcode ( ) = = IR : : Opcode : : GetCbufU32 & &
lhs . Inst ( ) - > Arg ( 0 ) = = rhs . Inst ( ) - > Arg ( 0 ) & & lhs . Inst ( ) - > Arg ( 1 ) = = rhs . Inst ( ) - > Arg ( 1 ) ) {
return true ;
}
return false ;
} } ;
return std : : ranges : : adjacent_find ( range , std : : not_fn ( equal ) , resolver ) = = std : : end ( range ) ;
}
void FoldGetRegister ( IR : : Inst & inst ) {
void FoldGetRegister ( IR : : Inst & inst ) {
if ( inst . Arg ( 0 ) . Reg ( ) = = IR : : Reg : : RZ ) {
if ( inst . Arg ( 0 ) . Reg ( ) = = IR : : Reg : : RZ ) {
inst . ReplaceUsesWith ( IR : : Value { u32 { 0 } } ) ;
inst . ReplaceUsesWith ( IR : : Value { u32 { 0 } } ) ;
@ -100,6 +121,157 @@ void FoldGetPred(IR::Inst& inst) {
}
}
}
}
/// Replaces the XMAD pattern generated by an integer FMA
bool FoldXmadMultiplyAdd ( IR : : Block & block , IR : : Inst & inst ) {
/*
* We are looking for this specific pattern :
* % 6 = BitFieldUExtract % op_b , # 0 , # 16
* % 7 = BitFieldUExtract % op_a ' , # 16 , # 16
* % 8 = IMul32 % 6 , % 7
* % 10 = BitFieldUExtract % op_a ' , # 0 , # 16
* % 11 = BitFieldInsert % 8 , % 10 , # 16 , # 16
* % 15 = BitFieldUExtract % op_b , # 0 , # 16
* % 16 = BitFieldUExtract % op_a , # 0 , # 16
* % 17 = IMul32 % 15 , % 16
* % 18 = IAdd32 % 17 , % op_c
* % 22 = BitFieldUExtract % op_b , # 16 , # 16
* % 23 = BitFieldUExtract % 11 , # 16 , # 16
* % 24 = IMul32 % 22 , % 23
* % 25 = ShiftLeftLogical32 % 24 , # 16
* % 26 = ShiftLeftLogical32 % 11 , # 16
* % 27 = IAdd32 % 26 , % 18
* % result = IAdd32 % 25 , % 27
*
* And replace it with :
* % temp = IMul32 % op_a , % op_b
* % result = IAdd32 % temp , % op_c
*
* This optimization has been proven safe by Nvidia ' s compiler logic being reversed .
* ( If Nvidia generates this code from ' fma ( a , b , c ) ' , we can do the same in the reverse order . )
*/
const IR : : Value zero { 0u } ;
const IR : : Value sixteen { 16u } ;
IR : : Inst * const _25 { inst . Arg ( 0 ) . TryInstRecursive ( ) } ;
IR : : Inst * const _27 { inst . Arg ( 1 ) . TryInstRecursive ( ) } ;
if ( ! _25 | | ! _27 ) {
return false ;
}
if ( _27 - > GetOpcode ( ) ! = IR : : Opcode : : IAdd32 ) {
return false ;
}
if ( _25 - > GetOpcode ( ) ! = IR : : Opcode : : ShiftLeftLogical32 | | _25 - > Arg ( 1 ) ! = sixteen ) {
return false ;
}
IR : : Inst * const _24 { _25 - > Arg ( 0 ) . TryInstRecursive ( ) } ;
if ( ! _24 | | _24 - > GetOpcode ( ) ! = IR : : Opcode : : IMul32 ) {
return false ;
}
IR : : Inst * const _22 { _24 - > Arg ( 0 ) . TryInstRecursive ( ) } ;
IR : : Inst * const _23 { _24 - > Arg ( 1 ) . TryInstRecursive ( ) } ;
if ( ! _22 | | ! _23 ) {
return false ;
}
if ( _22 - > GetOpcode ( ) ! = IR : : Opcode : : BitFieldUExtract ) {
return false ;
}
if ( _23 - > GetOpcode ( ) ! = IR : : Opcode : : BitFieldUExtract ) {
return false ;
}
if ( _22 - > Arg ( 1 ) ! = sixteen | | _22 - > Arg ( 2 ) ! = sixteen ) {
return false ;
}
if ( _23 - > Arg ( 1 ) ! = sixteen | | _23 - > Arg ( 2 ) ! = sixteen ) {
return false ;
}
IR : : Inst * const _11 { _23 - > Arg ( 0 ) . TryInstRecursive ( ) } ;
if ( ! _11 | | _11 - > GetOpcode ( ) ! = IR : : Opcode : : BitFieldInsert ) {
return false ;
}
if ( _11 - > Arg ( 2 ) ! = sixteen | | _11 - > Arg ( 3 ) ! = sixteen ) {
return false ;
}
IR : : Inst * const _8 { _11 - > Arg ( 0 ) . TryInstRecursive ( ) } ;
IR : : Inst * const _10 { _11 - > Arg ( 1 ) . TryInstRecursive ( ) } ;
if ( ! _8 | | ! _10 ) {
return false ;
}
if ( _8 - > GetOpcode ( ) ! = IR : : Opcode : : IMul32 ) {
return false ;
}
if ( _10 - > GetOpcode ( ) ! = IR : : Opcode : : BitFieldUExtract ) {
return false ;
}
IR : : Inst * const _6 { _8 - > Arg ( 0 ) . TryInstRecursive ( ) } ;
IR : : Inst * const _7 { _8 - > Arg ( 1 ) . TryInstRecursive ( ) } ;
if ( ! _6 | | ! _7 ) {
return false ;
}
if ( _6 - > GetOpcode ( ) ! = IR : : Opcode : : BitFieldUExtract ) {
return false ;
}
if ( _7 - > GetOpcode ( ) ! = IR : : Opcode : : BitFieldUExtract ) {
return false ;
}
if ( _6 - > Arg ( 1 ) ! = zero | | _6 - > Arg ( 2 ) ! = sixteen ) {
return false ;
}
if ( _7 - > Arg ( 1 ) ! = sixteen | | _7 - > Arg ( 2 ) ! = sixteen ) {
return false ;
}
IR : : Inst * const _26 { _27 - > Arg ( 0 ) . TryInstRecursive ( ) } ;
IR : : Inst * const _18 { _27 - > Arg ( 1 ) . TryInstRecursive ( ) } ;
if ( ! _26 | | ! _18 ) {
return false ;
}
if ( _26 - > GetOpcode ( ) ! = IR : : Opcode : : ShiftLeftLogical32 | | _26 - > Arg ( 1 ) ! = sixteen ) {
return false ;
}
if ( _26 - > Arg ( 0 ) . InstRecursive ( ) ! = _11 ) {
return false ;
}
if ( _18 - > GetOpcode ( ) ! = IR : : Opcode : : IAdd32 ) {
return false ;
}
IR : : Inst * const _17 { _18 - > Arg ( 0 ) . TryInstRecursive ( ) } ;
if ( ! _17 | | _17 - > GetOpcode ( ) ! = IR : : Opcode : : IMul32 ) {
return false ;
}
IR : : Inst * const _15 { _17 - > Arg ( 0 ) . TryInstRecursive ( ) } ;
IR : : Inst * const _16 { _17 - > Arg ( 1 ) . TryInstRecursive ( ) } ;
if ( ! _15 | | ! _16 ) {
return false ;
}
if ( _15 - > GetOpcode ( ) ! = IR : : Opcode : : BitFieldUExtract ) {
return false ;
}
if ( _16 - > GetOpcode ( ) ! = IR : : Opcode : : BitFieldUExtract ) {
return false ;
}
if ( _15 - > Arg ( 1 ) ! = zero | | _16 - > Arg ( 1 ) ! = zero | | _10 - > Arg ( 1 ) ! = zero ) {
return false ;
}
if ( _15 - > Arg ( 2 ) ! = sixteen | | _16 - > Arg ( 2 ) ! = sixteen | | _10 - > Arg ( 2 ) ! = sixteen ) {
return false ;
}
const std : : array < IR : : Value , 3 > op_as {
_7 - > Arg ( 0 ) . Resolve ( ) ,
_16 - > Arg ( 0 ) . Resolve ( ) ,
_10 - > Arg ( 0 ) . Resolve ( ) ,
} ;
const std : : array < IR : : Value , 3 > op_bs {
_22 - > Arg ( 0 ) . Resolve ( ) ,
_6 - > Arg ( 0 ) . Resolve ( ) ,
_15 - > Arg ( 0 ) . Resolve ( ) ,
} ;
const IR : : U32 op_c { _18 - > Arg ( 1 ) } ;
if ( ! AreEqual ( op_as ) | | ! AreEqual ( op_bs ) ) {
return false ;
}
IR : : IREmitter ir { block , IR : : Block : : InstructionList : : s_iterator_to ( inst ) } ;
inst . ReplaceUsesWith ( ir . IAdd ( ir . IMul ( IR : : U32 { op_as [ 0 ] } , IR : : U32 { op_bs [ 1 ] } ) , op_c ) ) ;
return true ;
}
/// Replaces the pattern generated by two XMAD multiplications
/// Replaces the pattern generated by two XMAD multiplications
bool FoldXmadMultiply ( IR : : Block & block , IR : : Inst & inst ) {
bool FoldXmadMultiply ( IR : : Block & block , IR : : Inst & inst ) {
/*
/*
@ -179,6 +351,9 @@ void FoldAdd(IR::Block& block, IR::Inst& inst) {
if ( FoldXmadMultiply ( block , inst ) ) {
if ( FoldXmadMultiply ( block , inst ) ) {
return ;
return ;
}
}
if ( FoldXmadMultiplyAdd ( block , inst ) ) {
return ;
}
}
}
}
}