main.swift 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122
  1. import Metal
  2. import MetalKit
  3. import Foundation
  /// Throughput benchmark that times two Metal compute kernels,
  /// `optimized_fp32` and `optimized_fp16`, and reports the result in TFLOPS.
  ///
  /// The kernels are compiled at runtime from `metalCode`, a Metal source string
  /// that is defined outside this class.
  class OptimizedMetalBenchmark {
      private let device: MTLDevice
      private let commandQueue: MTLCommandQueue
      private let pipelineStateFP32: MTLComputePipelineState
      private let pipelineStateFP16: MTLComputePipelineState

      // Tuning parameters.
      // NOTE(review): elementsPerThread and loopUnrollFactor presumably have to match
      // the constants used inside the kernels in `metalCode` — confirm against that source.

      /// Number of elements each GPU thread is expected to process.
      let elementsPerThread = 4
      /// Per-element loop count used by the FLOP estimate.
      let loopUnrollFactor = 512
      /// Number of elements in the work buffer: 1 << 26 = 67,108,864.
      let arrayLength = 1 << 26
      /// Number of timed dispatches per precision.
      let iterations = 100
      /// Threads per threadgroup (1-D).
      let threadgroupWidth = 256

      /// Acquires the default Metal device, compiles `metalCode`, and builds one
      /// compute pipeline per kernel. Every failure is fatal and reports its cause.
      init() {
          guard let device = MTLCreateSystemDefaultDevice(),
                let commandQueue = device.makeCommandQueue() else {
              fatalError("Metal初始化失败")
          }
          self.device = device
          self.commandQueue = commandQueue

          // Compile the shader source; surface the compiler diagnostics instead of
          // trapping on `try!`.
          let library: MTLLibrary
          do {
              library = try device.makeLibrary(source: metalCode, options: nil)
          } catch {
              fatalError("Metal着色器库编译失败: \(error)")
          }

          // Look the entry points up by name; a missing kernel is reported by name
          // rather than as an anonymous force-unwrap crash.
          guard let fp32Function = library.makeFunction(name: "optimized_fp32") else {
              fatalError("找不到内核函数: optimized_fp32")
          }
          guard let fp16Function = library.makeFunction(name: "optimized_fp16") else {
              fatalError("找不到内核函数: optimized_fp16")
          }

          do {
              pipelineStateFP32 = try device.makeComputePipelineState(function: fp32Function)
              pipelineStateFP16 = try device.makeComputePipelineState(function: fp16Function)
          } catch {
              fatalError("创建计算管线失败: \(error)")
          }

          print("优化版Metal基准测试初始化完成 - 设备: \(device.name)")
      }

      /// Runs the FP32 test, then the FP16 test, and prints a summary of both.
      func runOptimizedBenchmark() {
          print("\n运行优化版FP32测试...")
          let fp32Result = runTest(pipeline: pipelineStateFP32, precision: "FP32")

          print("\n运行优化版FP16测试...")
          let fp16Result = runTest(pipeline: pipelineStateFP16, precision: "FP16")

          print("\n最终结果:")
          print("FP32峰值性能: \(String(format: "%.2f", fp32Result)) TFLOPS")
          print("FP16峰值性能: \(String(format: "%.2f", fp16Result)) TFLOPS")
      }

      /// Times `iterations` dispatches of `pipeline` and converts the elapsed
      /// wall-clock time into TFLOPS.
      ///
      /// - Parameters:
      ///   - pipeline: Compute pipeline to benchmark.
      ///   - precision: Label used only in the printed report.
      /// - Returns: Estimated throughput in TFLOPS.
      private func runTest(pipeline: MTLComputePipelineState, precision: String) -> Double {
          // The buffer is sized for `Float` for every pipeline, including FP16.
          // NOTE(review): for the FP16 kernel this is presumably an over-allocation — confirm.
          let bufferSize = arrayLength * MemoryLayout<Float>.size
          guard let buffer = device.makeBuffer(length: bufferSize, options: .storageModePrivate) else {
              fatalError("缓冲区创建失败")
          }

          let threadsPerGrid = arrayLength / elementsPerThread
          // Round up so a partial trailing threadgroup is still dispatched.
          let threadgroups = MTLSize(
              width: (threadsPerGrid + threadgroupWidth - 1) / threadgroupWidth,
              height: 1,
              depth: 1
          )
          let threadgroupSize = MTLSize(width: threadgroupWidth, height: 1, depth: 1)

          let dispatchOnce = {
              self.runKernel(pipeline: pipeline, buffer: buffer,
                             threadsPerGrid: threadsPerGrid,
                             threadgroups: threadgroups,
                             threadgroupSize: threadgroupSize)
          }

          // Warm-up: three untimed dispatches.
          for _ in 0..<3 {
              dispatchOnce()
          }

          // Timed run. Each dispatch is committed and waited on individually, so
          // the interval also includes per-dispatch submission overhead.
          let start = CFAbsoluteTimeGetCurrent()
          for _ in 0..<iterations {
              dispatchOnce()
          }
          let elapsed = CFAbsoluteTimeGetCurrent() - start

          // FLOP estimate: every element is charged `loopUnrollFactor` loop steps at
          // 8 FLOPs per step, for each timed dispatch.
          // NOTE(review): the 8-FLOPs-per-step figure and the loop count cannot be
          // checked here — confirm both against the kernel source.
          let totalFLOPs = Double(arrayLength) * Double(loopUnrollFactor) * 8.0 * Double(iterations)
          let tflops = (totalFLOPs / elapsed) / 1e12

          print("\(precision) 结果:")
          print("- 总时间: \(String(format: "%.3f", elapsed))秒")
          print("- 理论计算量: \(String(format: "%.1f", totalFLOPs/1e12)) TFLOP")
          print("- 实测性能: \(String(format: "%.2f", tflops)) TFLOPS")

          return tflops
      }

      /// Encodes one dispatch of `pipeline` over `buffer`, submits it, and blocks
      /// until the GPU has finished.
      ///
      /// - Parameters:
      ///   - pipeline: Compute pipeline to execute.
      ///   - buffer: Bound at buffer index 0.
      ///   - threadsPerGrid: Currently unused; kept so existing call sites stay valid.
      ///   - threadgroups: Number of threadgroups to launch.
      ///   - threadgroupSize: Threads per threadgroup.
      private func runKernel(pipeline: MTLComputePipelineState,
                             buffer: MTLBuffer,
                             threadsPerGrid: Int,
                             threadgroups: MTLSize,
                             threadgroupSize: MTLSize) {
          guard let commandBuffer = commandQueue.makeCommandBuffer(),
                let encoder = commandBuffer.makeComputeCommandEncoder() else {
              fatalError("创建命令对象失败")
          }

          encoder.setComputePipelineState(pipeline)
          encoder.setBuffer(buffer, offset: 0, index: 0)
          encoder.dispatchThreadgroups(threadgroups,
                                       threadsPerThreadgroup: threadgroupSize)
          encoder.endEncoding()

          commandBuffer.commit()
          commandBuffer.waitUntilCompleted()

          // A failed command buffer completes early; without this check the
          // benchmark would time work that never ran and report inflated TFLOPS.
          if let error = commandBuffer.error {
              fatalError("GPU内核执行失败: \(error)")
          }
      }
  }
  // Entry point: build the benchmark harness, then run the FP32 and FP16 tests.
  let benchmark: OptimizedMetalBenchmark = .init()
  benchmark.runOptimizedBenchmark()