Skip to content

Commit 08c81ad

Browse files
authored
[Embedding] Fix bug of saving EmbeddingVariable with int32 type. (#692)
1 parent 60d515b commit 08c81ad

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

tensorflow/core/kernels/save_restore_v2_ops.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,8 @@ class SaveV2 : public OpKernel {
171171
const string& tensor_name = tensor_names_flat(i);
172172
if (tensor_types_[i] == DT_RESOURCE) {
173173
auto& handle = HandleFromInput(context, i + kFixedInputs);
174-
if (IsHandle<EmbeddingVar<int64, float>>(handle)) {
174+
if (IsHandle<EmbeddingVar<int64, float>>(handle) ||
175+
IsHandle<EmbeddingVar<int32, float>>(handle)) {
175176
if (ev_key_types_[start_ev_key_index] == DT_INT32) {
176177
DumpEvWithGlobalStep<int32, float>(context,
177178
i + kFixedInputs, tensor_name, writer, tensor_types_[0]);

0 commit comments

Comments
 (0)