GPUKernelContest/cp_template/sort_pair_algorithm.maca

275 lines
10 KiB
Plaintext
Executable File
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#include "test_utils.h"
#include "performance_utils.h"
#include "yaml_reporter.h"
#include <iostream>
#include <vector>
#include <iomanip>
// ============================================================================
// 实现标记宏 - 参赛者修改实现时请将此宏设为0
// ============================================================================
#ifndef USE_DEFAULT_REF_IMPL
#define USE_DEFAULT_REF_IMPL 1 // 1=默认实现, 0=参赛者自定义实现
#endif
#if USE_DEFAULT_REF_IMPL
#include <thrust/sort.h>
#include <thrust/device_vector.h>
#include <thrust/execution_policy.h>
#include <thrust/iterator/zip_iterator.h>
#include <thrust/tuple.h>
#endif
// ============================================================================
// SortPair算法实现接口
// 参赛者需要替换Thrust实现为自己的高性能kernel
// ============================================================================
/// Key/value pair sort used by the SortPair benchmark.
///
/// Sorts `num_items` (key, value) pairs by key and writes the result to the
/// output buffers; the input buffers are only read. The default reference
/// implementation uses a stable sort, so pairs with equal keys keep their
/// input order. Contestants replace the body of sort() with their own
/// kernels and build with USE_DEFAULT_REF_IMPL set to 0.
///
/// @tparam KeyType   key element type (this file's harness uses float)
/// @tparam ValueType payload element type (this file's harness uses uint32_t)
template <typename KeyType, typename ValueType>
class SortPairAlgorithm {
public:
    /// Main entry point - contestants implement this function.
    ///
    /// @param d_keys_in    device pointer to num_items input keys (read only)
    /// @param d_keys_out   device pointer receiving the num_items sorted keys
    /// @param d_values_in  device pointer to num_items input values (read only)
    /// @param d_values_out device pointer receiving the values, permuted
    ///                     together with their keys
    /// @param num_items    number of pairs; a value <= 0 is an empty input and
    ///                     nothing is written
    /// @param descending   true sorts keys largest-first, false smallest-first
    void sort(const KeyType* d_keys_in, KeyType* d_keys_out,
              const ValueType* d_values_in, ValueType* d_values_out,
              int num_items, bool descending) {
        // Nothing to sort. Returning here also protects the size computations
        // below: a negative int converted to size_t becomes an enormous byte
        // count for mcMemcpy, and `ptr + num_items` would point before the
        // start of the buffer.
        if (num_items <= 0) {
            return;
        }
#if !USE_DEFAULT_REF_IMPL
        // ========================================
        // Contestant implementation area
        // ========================================
        // TODO: implement your own high-performance sort here.
        // For example, launch one or more custom kernels:
        // preprocessKernel<<<grid, block>>>(d_keys_in, d_values_in, num_items);
        // mainSortKernel<<<grid, block>>>(d_keys_out, d_values_out, num_items, descending);
        // postprocessKernel<<<grid, block>>>(d_keys_out, d_values_out, num_items);
#else
        // ========================================
        // Default reference implementation
        // ========================================
        // Copy the inputs into the output buffers and sort those in place, so
        // the caller's input buffers are left untouched.
        MACA_CHECK(mcMemcpy(d_keys_out, d_keys_in, num_items * sizeof(KeyType), mcMemcpyDeviceToDevice));
        MACA_CHECK(mcMemcpy(d_values_out, d_values_in, num_items * sizeof(ValueType), mcMemcpyDeviceToDevice));
        auto key_ptr = thrust::device_pointer_cast(d_keys_out);
        auto value_ptr = thrust::device_pointer_cast(d_values_out);
        // stable_sort_by_key permutes the values alongside the keys; the
        // comparator selects the sort direction.
        if (descending) {
            thrust::stable_sort_by_key(thrust::device, key_ptr, key_ptr + num_items, value_ptr, thrust::greater<KeyType>());
        } else {
            thrust::stable_sort_by_key(thrust::device, key_ptr, key_ptr + num_items, value_ptr, thrust::less<KeyType>());
        }
#endif
    }

    /// Reports which implementation was compiled in:
    /// "DEFAULT_REF_IMPL" or "CUSTOM_IMPL".
    static const char* getImplementationStatus() {
#if USE_DEFAULT_REF_IMPL
        return "DEFAULT_REF_IMPL";
#else
        return "CUSTOM_IMPL";
#endif
    }

private:
    // Contestants may add helper functions and member variables here,
    // e.g. temporary buffers, additional kernels, streams.
};
// ============================================================================
// 测试和性能评估
// ============================================================================
/// Correctness check for SortPairAlgorithm<float, uint32_t>.
///
/// Generates one random data set of 10000 pairs and uploads it once. It then
/// sorts the set on the device in ascending and then descending order, and
/// compares keys (with tolerance 1e-5) and values against the host reference
/// produced by cpuSortPair().
///
/// @return true if both directions match the reference, false otherwise.
bool testCorrectness() {
    std::cout << "SortPair 正确性测试..." << std::endl;

    TestDataGenerator generator;
    SortPairAlgorithm<float, uint32_t> algorithm;

    // Small-scale data set; keys are generated before values.
    int n = 10000;
    auto host_keys = generator.generateRandomFloats(n);
    auto host_values = generator.generateRandomUint32(n);

    const auto key_bytes = n * sizeof(float);
    const auto value_bytes = n * sizeof(uint32_t);

    // Device buffers; the inputs are uploaded once and reused by both passes.
    float* dev_keys_in = nullptr;
    float* dev_keys_out = nullptr;
    uint32_t* dev_values_in = nullptr;
    uint32_t* dev_values_out = nullptr;
    MACA_CHECK(mcMalloc(&dev_keys_in, key_bytes));
    MACA_CHECK(mcMalloc(&dev_keys_out, key_bytes));
    MACA_CHECK(mcMalloc(&dev_values_in, value_bytes));
    MACA_CHECK(mcMalloc(&dev_values_out, value_bytes));
    MACA_CHECK(mcMemcpy(dev_keys_in, host_keys.data(), key_bytes, mcMemcpyHostToDevice));
    MACA_CHECK(mcMemcpy(dev_values_in, host_values.data(), value_bytes, mcMemcpyHostToDevice));

    bool allPassed = true;
    // Pass 0 checks ascending order, pass 1 checks descending order.
    for (int pass = 0; pass < 2; ++pass) {
        const bool descending = (pass == 1);
        std::cout << " " << (descending ? "降序" : "升序") << " 测试..." << std::endl;

        // Host reference result.
        auto ref_keys = host_keys;
        auto ref_values = host_values;
        cpuSortPair(ref_keys, ref_values, descending);

        // Device result, copied back to the host.
        algorithm.sort(dev_keys_in, dev_keys_out, dev_values_in, dev_values_out, n, descending);
        std::vector<float> out_keys(n);
        std::vector<uint32_t> out_values(n);
        MACA_CHECK(mcMemcpy(out_keys.data(), dev_keys_out, key_bytes, mcMemcpyDeviceToHost));
        MACA_CHECK(mcMemcpy(out_values.data(), dev_values_out, value_bytes, mcMemcpyDeviceToHost));

        // Both comparisons always run (no short-circuit between them).
        const bool keys_ok = compareArrays(ref_keys, out_keys, 1e-5);
        const bool values_ok = compareArrays(ref_values, out_values);
        if (keys_ok && values_ok) {
            std::cout << " 通过" << std::endl;
        } else {
            std::cout << " 失败: 结果不匹配" << std::endl;
            allPassed = false;
        }
    }

    // Release device buffers.
    mcFree(dev_keys_in);
    mcFree(dev_keys_out);
    mcFree(dev_values_in);
    mcFree(dev_values_out);
    return allPassed;
}
/// Benchmarks SortPairAlgorithm<float, uint32_t> for every size in
/// TEST_SIZES, in ascending and then descending order.
///
/// For each size it prints one table row and records one YAML entry.
/// Afterwards it writes all entries to "sort_pair_performance.yaml". Each
/// direction gets WARMUP_ITERATIONS untimed calls followed by
/// BENCHMARK_ITERATIONS timed calls whose PerformanceMeter readings are
/// averaged. The readings are presumably milliseconds, since the YAML keys
/// are *_time_ms.
void benchmarkPerformance() {
    PerformanceDisplay::printSortPairHeader();

    TestDataGenerator generator;
    PerformanceMeter meter;
    SortPairAlgorithm<float, uint32_t> algorithm;

    const int WARMUP_ITERATIONS = 5;
    const int BENCHMARK_ITERATIONS = 10;

    // One row per data size for the YAML report.
    std::vector<std::map<std::string, std::string>> perf_data;

    for (int idx = 0; idx < NUM_TEST_SIZES; ++idx) {
        int n = TEST_SIZES[idx];

        // Random input data (keys generated before values).
        auto host_keys = generator.generateRandomFloats(n);
        auto host_values = generator.generateRandomUint32(n);

        const auto key_bytes = n * sizeof(float);
        const auto value_bytes = n * sizeof(uint32_t);

        float* dev_keys_in = nullptr;
        float* dev_keys_out = nullptr;
        uint32_t* dev_values_in = nullptr;
        uint32_t* dev_values_out = nullptr;
        MACA_CHECK(mcMalloc(&dev_keys_in, key_bytes));
        MACA_CHECK(mcMalloc(&dev_keys_out, key_bytes));
        MACA_CHECK(mcMalloc(&dev_values_in, value_bytes));
        MACA_CHECK(mcMalloc(&dev_values_out, value_bytes));
        MACA_CHECK(mcMemcpy(dev_keys_in, host_keys.data(), key_bytes, mcMemcpyHostToDevice));
        MACA_CHECK(mcMemcpy(dev_values_in, host_values.data(), value_bytes, mcMemcpyHostToDevice));

        // Warm up, then return the mean meter reading of one sort call
        // in the requested direction.
        auto average_sort_time = [&](bool descending) {
            for (int w = 0; w < WARMUP_ITERATIONS; ++w) {
                algorithm.sort(dev_keys_in, dev_keys_out, dev_values_in, dev_values_out, n, descending);
            }
            float accumulated = 0;
            for (int b = 0; b < BENCHMARK_ITERATIONS; ++b) {
                meter.startTiming();
                algorithm.sort(dev_keys_in, dev_keys_out, dev_values_in, dev_values_out, n, descending);
                accumulated += meter.stopTiming();
            }
            return accumulated / BENCHMARK_ITERATIONS;
        };

        // Ascending is measured first, then descending.
        float asc_time = average_sort_time(false);
        float desc_time = average_sort_time(true);

        // Derived metrics and console output.
        auto asc_metrics = PerformanceCalculator::calculateSortPair(n, asc_time);
        auto desc_metrics = PerformanceCalculator::calculateSortPair(n, desc_time);
        PerformanceDisplay::printSortPairData(n, asc_time, desc_time, asc_metrics, desc_metrics);

        // YAML report row.
        auto entry = YAMLPerformanceReporter::createEntry();
        entry["data_size"] = std::to_string(n);
        entry["asc_time_ms"] = std::to_string(asc_time);
        entry["desc_time_ms"] = std::to_string(desc_time);
        entry["asc_throughput_gps"] = std::to_string(asc_metrics.throughput_gps);
        entry["desc_throughput_gps"] = std::to_string(desc_metrics.throughput_gps);
        entry["key_type"] = "float";
        entry["value_type"] = "uint32_t";
        perf_data.push_back(entry);

        // Release this size's device buffers.
        mcFree(dev_keys_in);
        mcFree(dev_keys_out);
        mcFree(dev_values_in);
        mcFree(dev_values_out);
    }

    YAMLPerformanceReporter::generateSortPairYAML(perf_data, "sort_pair_performance.yaml");
    PerformanceDisplay::printSavedMessage("sort_pair_performance.yaml");
}
// ============================================================================
// 主函数
// ============================================================================
int main(int argc, char* argv[]) {
std::cout << "=== SortPair 算法测试 ===" << std::endl;
// 检查参数
std::string mode = "all";
if (argc > 1) {
mode = argv[1];
}
bool correctness_passed = true;
bool performance_completed = true;
try {
if (mode == "correctness" || mode == "all") {
correctness_passed = testCorrectness();
}
if (mode == "performance" || mode == "all") {
if (correctness_passed || mode == "performance") {
benchmarkPerformance();
} else {
std::cout << "跳过性能测试,因为正确性测试未通过" << std::endl;
performance_completed = false;
}
}
std::cout << "\n=== 测试完成 ===" << std::endl;
std::cout << "实现状态: " << SortPairAlgorithm<float, uint32_t>::getImplementationStatus() << std::endl;
if (mode == "all") {
std::cout << "正确性: " << (correctness_passed ? "通过" : "失败") << std::endl;
std::cout << "性能测试: " << (performance_completed ? "完成" : "跳过") << std::endl;
}
return correctness_passed ? 0 : 1;
} catch (const std::exception& e) {
std::cerr << "测试出错: " << e.what() << std::endl;
return 1;
}
}