Commit 28a958a

Author: mochen.bmc
Commit message: Resize Bicubic
Parent: 4f83816

15 files changed: 821 additions & 7 deletions

xla/debug_options_flags.cc

Lines changed: 1 addition & 1 deletion
@@ -794,7 +794,7 @@ void MakeDebugOptionsFlags(std::vector<tsl::Flag>* flag_list,
       "this flag to false."));
   flag_list->push_back(tsl::Flag(
       "xla_multiheap_size_constraint_per_heap",
-      int32_setter_for(
+      int64_setter_for(
           &DebugOptions::set_xla_multiheap_size_constraint_per_heap),
       debug_options->xla_multiheap_size_constraint_per_heap(),
       "Generates multiple heaps (i.e., temp buffers) with a size "

xla/pjrt/gpu/gpu_helpers.cc

Lines changed: 3 additions & 1 deletion
@@ -72,7 +72,8 @@ void EnablePeerAccess(absl::Span<se::StreamExecutor* const> executors) {
 
 // Builds a BFCAllocator for all local GPUs.
 StatusOr<std::unique_ptr<tsl::BFCAllocator>> CreateBFCAllocator(
-    se::StreamExecutor* executor, double memory_fraction, bool preallocate) {
+    se::StreamExecutor* executor, double memory_fraction, bool preallocate,
+    bool garbage_collection) {
   bool enable_unified_memory;
   Status status = tsl::ReadBoolFromEnvVar("TF_FORCE_UNIFIED_MEMORY", false,
                                           &enable_unified_memory);
@@ -111,6 +112,7 @@ StatusOr<std::unique_ptr<tsl::BFCAllocator>> CreateBFCAllocator(
 
   tsl::BFCAllocator::Options opts;
   opts.allow_growth = !preallocate;
+  opts.garbage_collection = garbage_collection;
   return std::make_unique<tsl::BFCAllocator>(
       std::move(sub_allocator), allocator_memory,
       absl::StrCat("GPU_", device_ordinal, "_bfc"), opts);

xla/pjrt/gpu/gpu_helpers.h

Lines changed: 5 additions & 1 deletion
@@ -57,14 +57,18 @@ struct GpuAllocatorConfig {
   // fragmentation, allowing more of the total memory to be used. If false, the
   // allocator will allocate more memory as allocations are requested.
   bool preallocate = true;
+
+  // Whether to enable garbage collection in the BFC allocator.
+  bool garbage_collection = false;
 };
 
 std::unique_ptr<tsl::BFCAllocator> GetGpuHostAllocator(
     se::StreamExecutor* executor);
 
 // Builds a BFCAllocator for all local GPUs.
 StatusOr<std::unique_ptr<tsl::BFCAllocator>> CreateBFCAllocator(
-    se::StreamExecutor* executor, double memory_fraction, bool preallocate);
+    se::StreamExecutor* executor, double memory_fraction, bool preallocate,
+    bool garbage_collection);
 
 }  // namespace xla
 
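
The new garbage_collection field is threaded through CreateBFCAllocator (gpu_helpers.cc above) into tsl::BFCAllocator::Options, and the default keeps it off. A rough usage sketch with illustrative values only:

    #include "xla/pjrt/gpu/gpu_helpers.h"

    // Hypothetical configuration: grow on demand instead of preallocating,
    // and let the BFC allocator reclaim freed regions under memory pressure.
    xla::GpuAllocatorConfig MakeOnDemandAllocatorConfig() {
      xla::GpuAllocatorConfig config;
      config.memory_fraction = 0.75;     // cap at 75% of device memory
      config.preallocate = false;        // CreateBFCAllocator sets allow_growth = !preallocate
      config.garbage_collection = true;  // forwarded to BFCAllocator::Options
      return config;
    }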

xla/pjrt/gpu/se_gpu_pjrt_client.cc

Lines changed: 2 additions & 1 deletion
@@ -745,7 +745,8 @@ GetStreamExecutorGpuDeviceAllocator(
           auto bfc_allocator,
           CreateBFCAllocator(ordinal_and_device.second->executor(),
                              allocator_config.memory_fraction,
-                             allocator_config.preallocate));
+                             allocator_config.preallocate,
+                             allocator_config.garbage_collection));
       allocators_and_streams.emplace_back(
           std::move(bfc_allocator),
           ordinal_and_device.second->compute_stream());

xla/service/buffer_assignment.cc

Lines changed: 1 addition & 1 deletion
@@ -2019,7 +2019,7 @@ StatusOr<std::unique_ptr<BufferAssignment>> BufferAssigner::CreateAssignment(
       buffers_to_assign_sequentially.size() == global_computations.size();
   VLOG(2) << "Running whole module heap simulation: "
           << run_whole_module_heap_simulation;
-  const int32_t multiheap_size_constraint_per_heap =
+  const int64_t multiheap_size_constraint_per_heap =
       module->config().debug_options().xla_multiheap_size_constraint_per_heap();
   VLOG(2) << "Multiheap per heap size limit: "
          << multiheap_size_constraint_per_heap;

xla/service/buffer_assignment.h

Lines changed: 1 addition & 1 deletion
@@ -530,7 +530,7 @@ class BufferAssignment {
         color_alignment_(std::move(color_alignment)),
         alias_analysis_(std::move(alias_analysis)),
         hlo_live_range_(std::move(hlo_live_range)) {
-    int32_t raw_value = module->config()
+    int64_t raw_value = module->config()
                             .debug_options()
                             .xla_multiheap_size_constraint_per_heap();
     // -1 means no constraint.

xla/service/gpu/runtime/BUILD

Lines changed: 80 additions & 0 deletions
@@ -257,6 +257,7 @@ cc_library(
         ":stream_synchronization",
         ":support",
         ":topk",
+        ":resize_bicubic",
         ":tracing",
         "//xla:statusor",
         "//xla:xla_proto_cc",
@@ -417,6 +418,85 @@ cc_library(
     ],
 )
 
+cc_library(
+    name = "resize_bicubic_kernel",
+    srcs = if_cuda_is_configured(
+        [
+            "resize_bicubic_kernel.cc",
+        ],
+    ),
+    hdrs = if_cuda_is_configured(["resize_bicubic_kernel.h"]),
+    compatible_with = [],
+    deps = [
+        ":resize_bicubic_kernel_cuda",
+        # "//xla:shape_util",
+        "//xla:xla_proto_cc",
+        "//xla:xla_data_proto_cc",
+        "//xla/runtime:memref_view",
+        "//xla/stream_executor:platform",
+        "//xla/stream_executor:stream_executor_headers",  # build_cleaner: keep
+        "//xla/stream_executor/gpu:gpu_stream_header",
+        "//xla/stream_executor/gpu:gpu_types_header",
+        "@com_google_absl//absl/numeric:bits",
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/status:statusor",
+        "@local_config_cuda//cuda:cuda_headers",
+    ],
+)
+
+cuda_library(
+    name = "resize_bicubic_kernel_cuda",
+    srcs = if_cuda_is_configured(
+        [
+            "resize_bicubic_kernel.cu.cc",
+        ],
+    ),
+    hdrs = if_cuda_is_configured(["resize_bicubic_kernel_common.h"]),
+    compatible_with = [],
+    deps = [
+        "@eigen_archive//:eigen3",
+        "@local_config_cuda//cuda:cuda_headers",
+        "@com_google_absl//absl/types:span",
+    ],
+)
+
+
+cc_library(
+    name = "resize_bicubic",
+    srcs = if_cuda_is_configured(
+        ["resize_bicubic.cc"],
+    ),
+    hdrs = ["resize_bicubic.h"],
+    deps = if_cuda_is_configured([":resize_bicubic_kernel"]) + [
+        ":support",
+        "//xla:executable_run_options",
+        # "//xla:shape_util",
+        "//xla:status",
+        "//xla:statusor",
+        # "//xla:types",
+        "//xla:xla_data_proto_cc",
+        "//xla:xla_proto_cc",
+        "//xla/hlo/ir:hlo",
+        # "//xla/mlir/runtime/transforms:custom_call_encoding",
+        "//xla/runtime:custom_call",
+        "//xla/runtime:custom_call_registry",
+        "//xla/runtime:executable",
+        "//xla/runtime:state",
+        # "//xla/runtime/ffi:ffi_api",
+        # "//xla/runtime/ffi:ffi_c_api_hdrs",
+        "//xla/service:executable",
+        "//xla/service:hlo_pass",
+        "//xla/service:tuple_util",
+        "//xla/stream_executor/gpu:gpu_stream_header",
+        "//xla/stream_executor/gpu:gpu_types_header",
+        "@com_google_absl//absl/container:flat_hash_set",
+        "@com_google_absl//absl/log:check",
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/strings",
+        "@tsl//tsl/platform:statusor",
+    ],
+)
+
 cc_library(
     name = "gemm",
     srcs = ["gemm.cc"],

xla/service/gpu/runtime/executable.cc

Lines changed: 2 additions & 0 deletions
@@ -47,6 +47,7 @@ limitations under the License.
 #include "xla/service/gpu/runtime/stream_synchronization.h"
 #include "xla/service/gpu/runtime/support.h"
 #include "xla/service/gpu/runtime/topk.h"
+#include "xla/service/gpu/runtime/resize_bicubic.h"
 #include "xla/service/gpu/runtime/tracing.h"
 #include "xla/service/gpu/thunk.h"
 #include "xla/service/service_executable_run_options.h"
@@ -87,6 +88,7 @@ void RegisterXlaGpuRuntimeCustomCalls(DirectCustomCallRegistry& registry) {
   RegisterMemsetCustomCalls(registry);
   RegisterSendRecvCustomCalls(registry);
   RegisterTopkCustomCall(registry);
+  RegisterResizeBicubicCustomCall(registry);
 
 #if GOOGLE_CUDA || TF_HIPBLASLT
   RegisterMatmulCustomCalls(registry);

xla/service/gpu/runtime/resize_bicubic.cc

Lines changed: 88 additions & 0 deletions
@@ -0,0 +1,88 @@
+/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/runtime/resize_bicubic.h"
+
+#include <stdint.h>
+
+#include <cstddef>
+
+#include "absl/status/status.h"
+#include "absl/types/span.h"
+#include "xla/runtime/custom_call.h"
+#include "xla/runtime/executable.h"
+#include "xla/service/gpu/runtime/resize_bicubic_kernel.h"
+// #include "xla/runtime/custom_call_registry.h"
+
+#include "xla/service/gpu/runtime/support.h"
+#include "xla/service/service_executable_run_options.h"
+#include "xla/xla_data.pb.h"
+
+namespace xla::gpu {
+using ::xla::runtime::CustomCall;
+using ::xla::runtime::StridedMemrefView;
+
+static absl::Status ResizeBicubicImpl(
+    const ServiceExecutableRunOptions* run_options, StridedMemrefView input,
+    StridedMemrefView output, bool align_corners) {
+  float scales_h =
+      static_cast<float>(output.sizes[2]) / static_cast<float>(input.sizes[2]);
+  float scales_w =
+      static_cast<float>(output.sizes[3]) / static_cast<float>(input.sizes[3]);
+  se::StreamExecutor* executor = run_options->stream()->parent();
+  return RunResizeBicubicImpl(
+      se::gpu::AsGpuStreamValue(run_options->stream()),
+      executor->GetDeviceDescription().threads_per_block_limit(), input, output,
+      align_corners, scales_h, scales_w);
+}
+
+static absl::Status ResizeBicubicGradImpl(
+    const ServiceExecutableRunOptions* run_options,
+    StridedMemrefView grad_output, StridedMemrefView grad_input,
+    bool align_corners) {
+  float scales_h = static_cast<float>(grad_output.sizes[2]) /
+                   static_cast<float>(grad_input.sizes[2]);
+  float scales_w = static_cast<float>(grad_output.sizes[3]) /
+                   static_cast<float>(grad_input.sizes[3]);
+  se::StreamExecutor* executor = run_options->stream()->parent();
+  return RunResizeBicubicGradImpl(
+      se::gpu::AsGpuStreamValue(run_options->stream()),
+      executor->GetDeviceDescription().threads_per_block_limit(), grad_input,
+      grad_output, align_corners, scales_h, scales_w);
+}
+
+XLA_RUNTIME_DEFINE_CUSTOM_CALL(
+    ResizeBicubic, FunctionWrapper<ResizeBicubicImpl>(), checks,
+    CustomCall::Bind("__gpu$ResizeBicubic")
+        .UserData<const ServiceExecutableRunOptions*>()
+        .Arg<StridedMemrefView>()  // input
+        .Arg<StridedMemrefView>()  // output
+        .Attr<bool>("align_corners"));
+
+XLA_RUNTIME_DEFINE_CUSTOM_CALL(
+    ResizeBicubicGrad, FunctionWrapper<ResizeBicubicGradImpl>(), checks,
+    CustomCall::Bind("__gpu$ResizeBicubicGrad")
+        .UserData<const ServiceExecutableRunOptions*>()
+        .Arg<StridedMemrefView>()  // grad_output
+        .Arg<StridedMemrefView>()  // grad_input
+        .Attr<bool>("align_corners"));
+
+void RegisterResizeBicubicCustomCall(
+    runtime::DirectCustomCallRegistry& registry) {
+  registry.Register("__gpu$ResizeBicubic", ResizeBicubic);
+  registry.Register("__gpu$ResizeBicubicGrad", ResizeBicubicGrad);
+}
+
+}  // namespace xla::gpu
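
Both entry points derive the scale factors from sizes[2] and sizes[3], so the custom call implicitly assumes 4-D NCHW buffers (N, C, H, W). A small standalone sketch of the same arithmetic, with hypothetical names and example shapes:

    #include <cstdint>

    struct BicubicScales {
      float h;
      float w;
    };

    // Same computation as ResizeBicubicImpl above, written against plain
    // arrays: output spatial extent divided by input spatial extent.
    BicubicScales ComputeScales(const int64_t in_sizes[4],
                                const int64_t out_sizes[4]) {
      return {static_cast<float>(out_sizes[2]) / static_cast<float>(in_sizes[2]),
              static_cast<float>(out_sizes[3]) / static_cast<float>(in_sizes[3])};
    }

    // Example: in = {1, 3, 32, 32}, out = {1, 3, 64, 64}  ->  {2.0f, 2.0f}.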

xla/service/gpu/runtime/resize_bicubic.h

Lines changed: 28 additions & 0 deletions
@@ -0,0 +1,28 @@
+/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_RUNTIME_RESIZE_BICUBIC_H_
+#define XLA_SERVICE_GPU_RUNTIME_RESIZE_BICUBIC_H_
+
+#include "xla/runtime/custom_call_registry.h"
+
+namespace xla::gpu {
+
+// Registers XLA Gpu runtime ResizeBicubic custom calls.
+void RegisterResizeBicubicCustomCall(runtime::DirectCustomCallRegistry& registry);
+
+}  // namespace xla::gpu
+
+#endif  // XLA_SERVICE_GPU_RUNTIME_RESIZE_BICUBIC_H_
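
A minimal sketch of calling the declared registration hook; in this commit the only caller is RegisterXlaGpuRuntimeCustomCalls in executable.cc, so the wrapper below is purely illustrative:

    #include "xla/runtime/custom_call_registry.h"
    #include "xla/service/gpu/runtime/resize_bicubic.h"

    // Hypothetical wrapper: after this call the registry resolves the
    // "__gpu$ResizeBicubic" and "__gpu$ResizeBicubicGrad" targets.
    void RegisterResizeBicubicForIllustration(
        xla::runtime::DirectCustomCallRegistry& registry) {
      xla::gpu::RegisterResizeBicubicCustomCall(registry);
    }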
