1- // Copyright (C) 2021 - 2022 Advanced Micro Devices, Inc. All rights reserved.
1+ // Copyright (C) 2021 - 2023 Advanced Micro Devices, Inc. All rights reserved.
22//
33// Permission is hereby granted, free of charge, to any person obtaining a copy
44// of this software and associated documentation files (the "Software"), to deal
@@ -387,11 +387,26 @@ static RTCProcessType get_rtc_process_type()
387387 return RTCProcessType::DEFAULT;
388388}
389389
// Strip any target feature flags from a GPU architecture name.
//
// Everything from the first ':' onward is dropped, so
// "gfx90a:sramecc+:xnack-" becomes "gfx90a".  A name that contains
// no ':' is returned unchanged.
//
// The argument is taken by const reference to avoid copying the
// string on every call; substr() already produces the new string
// that is returned.
static std::string gpu_arch_strip_flags(const std::string& gpu_arch_with_flags)
{
    // find() yields npos when there is no ':'; substr() clamps that
    // count to the end of the string, so the whole name is kept
    return gpu_arch_with_flags.substr(0, gpu_arch_with_flags.find(':'));
}
394+
390395std::vector<char > cached_compile (const std::string& kernel_name,
391- const std::string& gpu_arch ,
396+ const std::string& gpu_arch_with_flags ,
392397 kernel_src_gen_t generate_src,
393398 const std::array<char , 32 >& generator_sum)
394399{
400+ // Supplied gpu arch may have extra flags on it
401+ // (e.g. gfx90a:sramecc+:xnack-). Strip those from the arch name
402+ // since omitting them will generate code that handles either
403+ // case.
404+ //
405+ // As of this writing, there are no known performance benefits to
406+ // including the flags. If that changes, we may need to be more
407+ // selective about which flags to strip.
408+ std::string gpu_arch = gpu_arch_strip_flags (gpu_arch_with_flags);
409+
395410 // check cache first
396411 std::vector<char > code;
397412 if (RTCCache::single)
@@ -534,8 +549,9 @@ void RTCCache::enable_write_mostly()
534549 sqlite3_step (wal_stmt.get ());
535550}
536551
537- void RTCCache::write_aot_cache (const std::string& output_path,
538- const std::array<char , 32 >& generator_sum)
552+ void RTCCache::write_aot_cache (const std::string& output_path,
553+ const std::array<char , 32 >& generator_sum,
554+ const std::vector<std::string>& gpu_archs)
539555{
540556 // remove the path if it already exists, since we want to output a
541557 // cleanly created file
@@ -559,6 +575,31 @@ void RTCCache::write_aot_cache(const std::string& output_path,
559575 + sqlite3_errmsg (db_user.get ()));
560576 sqlite3_reset (attach_stmt.get ());
561577
578+ // copy only the required arches, in case more are present in the
579+ // cache than we need
580+ auto create_temp_stmt = prepare_stmt (db_user,
581+ " CREATE TABLE IF NOT EXISTS temp.aot_arch ("
582+ " arch TEXT NOT NULL )" );
583+ if (sqlite3_step (create_temp_stmt.get ()) != SQLITE_DONE)
584+ throw std::runtime_error (std::string (" write_aot_cache create temp table: " )
585+ + sqlite3_errmsg (db_user.get ()));
586+
587+ auto insert_temp_stmt = prepare_stmt (db_user, " INSERT INTO temp.aot_arch VALUES ( ? )" );
588+ for (const auto & gpu_arch_with_flags : gpu_archs)
589+ {
590+ std::string gpu_arch = gpu_arch_strip_flags (gpu_arch_with_flags);
591+
592+ if (sqlite3_bind_text (
593+ insert_temp_stmt.get (), 1 , gpu_arch.c_str (), gpu_arch.size (), SQLITE_TRANSIENT)
594+ != SQLITE_OK)
595+ throw std::runtime_error (std::string (" write_aot_cache temp bind: " )
596+ + sqlite3_errmsg (db_user.get ()));
597+ if (sqlite3_step (insert_temp_stmt.get ()) != SQLITE_DONE)
598+ throw std::runtime_error (std::string (" write_aot_cache temp step: " )
599+ + sqlite3_errmsg (db_user.get ()));
600+ sqlite3_reset (insert_temp_stmt.get ());
601+ }
602+
562603 // copy the kernels over in a consistent order and zero out the timestamps
563604 auto copy_stmt = prepare_stmt (db_user,
564605 " INSERT INTO out_db.cache_v1 ("
@@ -571,11 +612,17 @@ void RTCCache::write_aot_cache(const std::string& output_path,
571612 " )"
572613 " SELECT kernel_name, arch, hip_version, generator_sum, code, 0 "
573614 " FROM cache_v1 "
574- " WHERE generator_sum = :generator_sum "
615+ " WHERE "
616+ " generator_sum = :generator_sum "
617+ " AND hip_version = :hip_version "
618+ " AND arch IN ("
619+ " SELECT arch FROM temp.aot_arch "
620+ " ) "
575621 " ORDER BY kernel_name, arch, hip_version" );
576622 if (sqlite3_bind_blob (
577623 copy_stmt.get (), 1 , generator_sum.data (), generator_sum.size (), SQLITE_TRANSIENT)
578- != SQLITE_OK)
624+ != SQLITE_OK
625+ || sqlite3_bind_int64 (copy_stmt.get (), 2 , HIP_VERSION) != SQLITE_OK)
579626 throw std::runtime_error (std::string (" write_aot_cache copy bind: " )
580627 + sqlite3_errmsg (db_user.get ()));
581628
0 commit comments