import Metal import MetalKit import Foundation class OptimizedMetalBenchmark { private let device: MTLDevice private let commandQueue: MTLCommandQueue private let pipelineStateFP32: MTLComputePipelineState private let pipelineStateFP16: MTLComputePipelineState // 优化参数 let elementsPerThread = 4 // 每个线程处理4个元素 let loopUnrollFactor = 512 // 循环展开因子 let arrayLength = 1 << 26 // 调整为4M元素(减少内存压力) let iterations = 100 let threadgroupWidth = 256 init() { guard let device = MTLCreateSystemDefaultDevice(), let commandQueue = device.makeCommandQueue() else { fatalError("Metal初始化失败") } self.device = device self.commandQueue = commandQueue let library = try! device.makeLibrary(source: metalCode, options: nil) let fp32Function = library.makeFunction(name: "optimized_fp32")! let fp16Function = library.makeFunction(name: "optimized_fp16")! do { pipelineStateFP32 = try device.makeComputePipelineState(function: fp32Function) pipelineStateFP16 = try device.makeComputePipelineState(function: fp16Function) } catch { fatalError("创建计算管线失败: \(error)") } print("优化版Metal基准测试初始化完成 - 设备: \(device.name)") } func runOptimizedBenchmark() { print("\n运行优化版FP32测试...") let fp32Result = runTest(pipeline: pipelineStateFP32, precision: "FP32") print("\n运行优化版FP16测试...") let fp16Result = runTest(pipeline: pipelineStateFP16, precision: "FP16") print("\n最终结果:") print("FP32峰值性能: \(String(format: "%.2f", fp32Result)) TFLOPS") print("FP16峰值性能: \(String(format: "%.2f", fp16Result)) TFLOPS") } private func runTest(pipeline: MTLComputePipelineState, precision: String) -> Double { let bufferSize = arrayLength * MemoryLayout.size guard let buffer = device.makeBuffer(length: bufferSize, options: .storageModePrivate) else { fatalError("缓冲区创建失败") } let threadsPerGrid = arrayLength / elementsPerThread let threadgroups = MTLSize( width: (threadsPerGrid + threadgroupWidth - 1) / threadgroupWidth, height: 1, depth: 1 ) let threadgroupSize = MTLSize(width: threadgroupWidth, height: 1, depth: 1) // 预热 for _ in 0..<3 { runKernel(pipeline: pipeline, buffer: buffer, threadsPerGrid: threadsPerGrid, threadgroups: threadgroups, threadgroupSize: threadgroupSize) } // 正式测试 let start = CFAbsoluteTimeGetCurrent() for _ in 0..