MetalShader.metal 998 B

12345678910111213141516171819202122232425262728293031323334353637
  1. #include <metal_stdlib>
  2. using namespace metal;
  // FP32 arithmetic-throughput kernel (the FLOP accounting below suggests it is
  // meant as a compute microbenchmark).
  //
  // Each thread loads the float4 at buffer[tid], runs it through 512 rounds of
  // a chain of four dependent fused multiply-adds, and writes the result back
  // to the same slot (in place).
  //
  // FLOP count, treating one fma as 2 FLOPs:
  //   4 fma x 4 lanes x 2 = 32 FLOPs per round, 512 x 32 = 16384 FLOPs/thread.
  //
  // Precondition: buffer(0) must hold at least as many float4 elements as
  // there are threads in the dispatched grid -- tid is not bounds-checked.
  //
  // The four multipliers multiply to ~0.989, so each round is a contraction.
  // Starting from a moderate finite value, the result settles (toward ~349.5 in
  // exact arithmetic) instead of overflowing.
  kernel void optimized_fp32(
      device float4 *buffer [[buffer(0)]],
      uint tid [[thread_position_in_grid]]
  ) {
      // Coefficients of the four affine steps, applied in this order.
      const float4 mul0 = float4(1.01), add0 = float4(0.97);
      const float4 mul1 = float4(0.99), add1 = float4(1.02);
      const float4 mul2 = float4(1.03), add2 = float4(0.98);
      const float4 mul3 = float4(0.96), add3 = float4(1.05);
      const int rounds = 512;

      float4 acc = buffer[tid];
      for (int iter = 0; iter < rounds; ++iter) {
          // Innermost fma runs first; every step consumes the previous result,
          // i.e. a serial dependency chain of four fmas per round.
          acc = fma(fma(fma(fma(acc, mul0, add0), mul1, add1), mul2, add2),
                    mul3, add3);
      }
      buffer[tid] = acc;
  }
  // FP16 counterpart of optimized_fp32: the same arithmetic-throughput loop,
  // but on half4 vectors.
  //
  // Each thread loads the half4 at buffer[tid], runs it through 512 rounds of
  // a chain of four dependent fused multiply-adds, and writes the result back
  // to the same slot (in place).
  //
  // FLOP count, treating one fma as 2 FLOPs:
  //   4 fma x 4 lanes x 2 = 32 FLOPs per round, 512 x 32 = 16384 FLOPs/thread.
  //
  // Precondition: buffer(0) must hold at least as many half4 elements as
  // there are threads in the dispatched grid -- tid is not bounds-checked.
  //
  // The coefficients are the fp32 kernel's values rounded to half. Their
  // multipliers multiply to ~0.989, so each round is a contraction. Starting
  // from a moderate finite value, the result settles (around ~349.5) well
  // inside half's range instead of overflowing.
  kernel void optimized_fp16(
      device half4 *buffer [[buffer(0)]],
      uint tid [[thread_position_in_grid]]
  ) {
      // Coefficients of the four affine steps, applied in this order.
      const half4 mul0 = half4(1.01), add0 = half4(0.97);
      const half4 mul1 = half4(0.99), add1 = half4(1.02);
      const half4 mul2 = half4(1.03), add2 = half4(0.98);
      const half4 mul3 = half4(0.96), add3 = half4(1.05);
      const int rounds = 512;

      half4 acc = buffer[tid];
      for (int iter = 0; iter < rounds; ++iter) {
          // Innermost fma runs first; every step consumes the previous result,
          // i.e. a serial dependency chain of four fmas per round.
          acc = fma(fma(fma(fma(acc, mul0, add0), mul1, add1), mul2, add2),
                    mul3, add3);
      }
      buffer[tid] = acc;
  }