Skip to content

Commit 4f83816

Browse files
[PJRT:C] Add PJRT_Executable_Fingerprint to support AOT compilation.
PiperOrigin-RevId: 575543476
1 parent f95b416 commit 4f83816

File tree

6 files changed

+68
-6
lines changed

6 files changed

+68
-6
lines changed

xla/pjrt/c/CHANGELOG.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
11
# PJRT C API changelog
22

3+
## 0.35 (Oct 20, 2023)
4+
5+
* Added PJRT_Executable_Fingerprint method
6+
* Deprecated PJRT_LoadedExecutable_Fingerprint
7+
38
## 0.34 (Oct 9, 2023)
49

510
* Added PJRT_Structure_Type::PJRT_Structure_Type_Profiler.

xla/pjrt/c/pjrt_c_api.h

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ extern "C" {
5353
// Changes include:
5454
// * Adding a new field to the PJRT_Api or argument structs
5555
// * Renaming a method or argument (doesn't affect ABI)
56-
#define PJRT_API_MINOR 34
56+
#define PJRT_API_MINOR 35
5757

5858
// The plugin should set the major_version and minor_version of
5959
// PJRT_Api.pjrt_api_version to be the `PJRT_API_MAJOR` and `PJRT_API_MINOR` in
@@ -1315,6 +1315,24 @@ PJRT_DEFINE_STRUCT_TRAITS(PJRT_Executable_SizeOfGeneratedCodeInBytes_Args,
13151315
typedef PJRT_Error* PJRT_Executable_SizeOfGeneratedCodeInBytes(
13161316
PJRT_Executable_SizeOfGeneratedCodeInBytes_Args* args);
13171317

1318+
struct PJRT_Executable_Fingerprint_Args {
1319+
size_t struct_size;
1320+
void* priv;
1321+
PJRT_Executable* executable;
1322+
// Has the lifetime of `executable`
1323+
const char* executable_fingerprint; // out
1324+
size_t executable_fingerprint_size; // out
1325+
};
1326+
PJRT_DEFINE_STRUCT_TRAITS(PJRT_Executable_Fingerprint_Args,
1327+
executable_fingerprint_size);
1328+
1329+
// A unique fingerprint for `executable`. Two executables that were produced by
1330+
// compiling with identical inputs (same program, compile options, compiler
1331+
// version, etc.) should have the same fingerprint. May not be implemented by
1332+
// all platforms.
1333+
typedef PJRT_Error* PJRT_Executable_Fingerprint(
1334+
PJRT_Executable_Fingerprint_Args* args);
1335+
13181336
struct PJRT_Executable_GetCostAnalysis_Args {
13191337
size_t struct_size;
13201338
void* priv;
@@ -1434,10 +1452,11 @@ struct PJRT_LoadedExecutable_Fingerprint_Args {
14341452
};
14351453
PJRT_DEFINE_STRUCT_TRAITS(PJRT_LoadedExecutable_Fingerprint_Args,
14361454
executable_fingerprint_size);
1437-
// A unique fingerprint for `executable`. Two executables that were produced by
1438-
// compiling with identical inputs (same program, compile options, compiler
1439-
// version, etc.) should have the same fingerprint. May not be implemented by
1440-
// all platforms.
1455+
// DEPRECATED. Will be removed in PJRT version 2.0. Please use
1456+
// PJRT_Executable_Fingerprint instead. A unique fingerprint for `executable`.
1457+
// Two executables that were produced by compiling with identical inputs (same
1458+
// program, compile options, compiler version, etc.) should have the same
1459+
// fingerprint. May not be implemented by all platforms.
14411460
typedef PJRT_Error* PJRT_LoadedExecutable_Fingerprint(
14421461
PJRT_LoadedExecutable_Fingerprint_Args* args);
14431462

@@ -2090,6 +2109,8 @@ typedef struct {
20902109
_PJRT_API_STRUCT_FIELD(PJRT_Buffer_CopyToMemory);
20912110

20922111
_PJRT_API_STRUCT_FIELD(PJRT_Client_CreateViewOfDeviceBuffer);
2112+
2113+
_PJRT_API_STRUCT_FIELD(PJRT_Executable_Fingerprint);
20932114
} PJRT_Api;
20942115

20952116
const size_t PJRT_Api_STRUCT_SIZE =

xla/pjrt/c/pjrt_c_api_wrapper_impl.cc

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -416,6 +416,8 @@ PJRT_Error* PJRT_Client_LookupAddressableDevice(
416416
return nullptr;
417417
}
418418

419+
// TODO: b/306669267 - this method is deprecated. When can we return
420+
// unimplemented?
419421
PJRT_Error* PJRT_LoadedExecutable_Fingerprint(
420422
PJRT_LoadedExecutable_Fingerprint_Args* args) {
421423
PJRT_RETURN_IF_ERROR(ActualStructSizeIsGreaterOrEqual(
@@ -1115,6 +1117,18 @@ PJRT_Error* PJRT_Executable_OptimizedProgram(
11151117
}
11161118
}
11171119

1120+
PJRT_Error* PJRT_Executable_Fingerprint(
1121+
PJRT_Executable_Fingerprint_Args* args) {
1122+
PJRT_RETURN_IF_ERROR(ActualStructSizeIsGreaterOrEqual(
1123+
"PJRT_Executable_Fingerprint_Args",
1124+
PJRT_Executable_Fingerprint_Args_STRUCT_SIZE, args->struct_size));
1125+
PJRT_RETURN_IF_ERROR(args->executable->fingerprint.status());
1126+
args->executable_fingerprint = args->executable->fingerprint.value().c_str();
1127+
args->executable_fingerprint_size =
1128+
args->executable->fingerprint.value().size();
1129+
return nullptr;
1130+
}
1131+
11181132
PJRT_Error* PJRT_Executable_GetCostAnalysis(
11191133
PJRT_Executable_GetCostAnalysis_Args* args) {
11201134
PJRT_RETURN_IF_ERROR(ActualStructSizeIsGreaterOrEqual(
@@ -2154,7 +2168,8 @@ PJRT_TopologyDescription* CreateWrapperDeviceTopology(
21542168

21552169
PJRT_Executable::PJRT_Executable(
21562170
std::shared_ptr<xla::PjRtExecutable> executable)
2157-
: executable(std::move(executable)) {}
2171+
: executable(std::move(executable)),
2172+
fingerprint(this->executable->FingerprintExecutable()) {}
21582173

21592174
PJRT_LoadedExecutable::PJRT_LoadedExecutable(
21602175
std::shared_ptr<xla::PjRtLoadedExecutable> executable, PJRT_Client* client)

xla/pjrt/c/pjrt_c_api_wrapper_impl.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,8 @@ struct PJRT_Executable {
9393
// Must be shared_ptr so that we can share with PJRT_LoadedExecutable.
9494
std::shared_ptr<xla::PjRtExecutable> executable;
9595

96+
xla::StatusOr<std::string> fingerprint;
97+
9698
// Used to synchronize concurrent setting of cached values.
9799
mutable absl::Mutex mutex;
98100

@@ -262,6 +264,7 @@ PJRT_Error* PJRT_LoadedExecutable_AddressableDevices(
262264
PJRT_Error* PJRT_Executable_NumOutputs(PJRT_Executable_NumOutputs_Args* args);
263265
PJRT_Error* PJRT_Executable_SizeOfGeneratedCodeInBytes(
264266
PJRT_Executable_SizeOfGeneratedCodeInBytes_Args* args);
267+
PJRT_Error* PJRT_Executable_Fingerprint(PJRT_Executable_Fingerprint_Args* args);
265268
PJRT_Error* PJRT_Executable_GetCostAnalysis(
266269
PJRT_Executable_GetCostAnalysis_Args* args);
267270
PJRT_Error* PJRT_Executable_OutputElementTypes(
@@ -286,6 +289,8 @@ PJRT_Error* PJRT_Executable_DeserializeAndLoad(
286289
PJRT_Executable_DeserializeAndLoad_Args* args);
287290
PJRT_Error* PJRT_LoadedExecutable_GetExecutable(
288291
PJRT_LoadedExecutable_GetExecutable_Args* args);
292+
// TODO: b/306669267 - this method is deprecated. When can we return
293+
// unimplemented?
289294
PJRT_Error* PJRT_LoadedExecutable_Fingerprint(
290295
PJRT_LoadedExecutable_Fingerprint_Args* args);
291296

@@ -563,6 +568,7 @@ constexpr PJRT_Api CreatePjrtApi(
563568
pjrt::PJRT_Buffer_CopyToMemory,
564569
/*PJRT_Client_CreateViewOfDeviceBuffer=*/
565570
pjrt::PJRT_Client_CreateViewOfDeviceBuffer,
571+
/*PJRT_Executable_Fingerprint=*/pjrt::PJRT_Executable_Fingerprint,
566572
};
567573
}
568574

xla/pjrt/pjrt_c_api_client.cc

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1140,6 +1140,19 @@ StatusOr<std::string> PjRtCApiExecutable::SerializeExecutable() const {
11401140
return std::string(ser_args.serialized_bytes, ser_args.serialized_bytes_size);
11411141
}
11421142

1143+
StatusOr<std::string> PjRtCApiExecutable::FingerprintExecutable() const {
1144+
PJRT_Executable_Fingerprint_Args args;
1145+
args.struct_size = PJRT_Executable_Fingerprint_Args_STRUCT_SIZE;
1146+
args.priv = nullptr;
1147+
args.executable = c_executable();
1148+
1149+
RETURN_STATUS_IF_PJRT_ERROR(c_api_->PJRT_Executable_Fingerprint(&args),
1150+
c_api_);
1151+
1152+
return std::string(args.executable_fingerprint,
1153+
args.executable_fingerprint_size);
1154+
}
1155+
11431156
// ------------------------ Loaded Executables ---------------------------------
11441157

11451158
PjRtCApiLoadedExecutable::PjRtCApiLoadedExecutable(

xla/pjrt/pjrt_c_api_client.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -518,6 +518,8 @@ class PjRtCApiExecutable : public PjRtExecutable {
518518

519519
StatusOr<std::string> SerializeExecutable() const override;
520520

521+
StatusOr<std::string> FingerprintExecutable() const override;
522+
521523
private:
522524
const PJRT_Api* c_api_;
523525
std::unique_ptr<PJRT_Executable, ::pjrt::PJRT_ExecutableDeleter> executable_;

0 commit comments

Comments
 (0)