Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions build_rocm_python3
Original file line number Diff line number Diff line change
@@ -47,15 +47,15 @@ if [ -f /usertools/rocm.bazelrc ]; then
if [[ -n $nightly ]]; then
# Remove any previous builds and build nightly
rm -f $TF_PKG_LOC/tf_nightly_rocm*.whl
python3 tensorflow/tools/ci_build/update_version.py --nightly --rocm_version &&
bazel --bazelrc=/usertools/rocm.bazelrc build $RESOURCE_OPTION --config=rocm --action_env=TF_PYTHON_VERSION=$PYTHON_VERSION tensorflow/tools/pip_package:build_pip_package --verbose_failures &&
#python3 tensorflow/tools/ci_build/update_version.py --nightly --rocm_version &&
bazel --bazelrc=/usertools/rocm.bazelrc build $RESOURCE_OPTION --cxxopt="-D_GLIBCXX_USE_CXX11_ABI=0" --cxxopt="-DTENSORFLOW_HSACO_USE_ROCM_LLVM" --config=v1 --config=rocm --action_env=TF_PYTHON_VERSION=$PYTHON_VERSION tensorflow/tools/pip_package:build_pip_package --verbose_failures &&
./bazel-bin/tensorflow/tools/pip_package/build_pip_package $TF_PKG_LOC --rocm --nightly_flag &&
pip3 install --upgrade $TF_PKG_LOC/tf_nightly_rocm*.whl
else
# Remove any previous builds and build release
rm -f $TF_PKG_LOC/tensorflow*.whl
python3 tensorflow/tools/ci_build/update_version.py --rocm_version &&
bazel --bazelrc=/usertools/rocm.bazelrc build $RESOURCE_OPTION --config=rocm --action_env=TF_PYTHON_VERSION=$PYTHON_VERSION tensorflow/tools/pip_package:build_pip_package --verbose_failures &&
#python3 tensorflow/tools/ci_build/update_version.py --rocm_version &&
bazel --bazelrc=/usertools/rocm.bazelrc build $RESOURCE_OPTION --cxxopt="-D_GLIBCXX_USE_CXX11_ABI=0" --cxxopt="-DTENSORFLOW_HSACO_USE_ROCM_LLVM" --config=v1 --config=rocm --action_env=TF_PYTHON_VERSION=$PYTHON_VERSION tensorflow/tools/pip_package:build_pip_package --verbose_failures &&
./bazel-bin/tensorflow/tools/pip_package/build_pip_package $TF_PKG_LOC --rocm --project_name tensorflow_rocm &&
pip3 install --upgrade $TF_PKG_LOC/tensorflow*.whl
fi
@@ -66,13 +66,13 @@ else
if [[ -n $nightly ]]; then
# Remove any previous builds and build nightly
rm -f $TF_PKG_LOC/tf_nightly_rocm*.whl
bazel build $RESOURCE_OPTION --config=opt --config=rocm //tensorflow/tools/pip_package:build_pip_package --verbose_failures &&
bazel build $RESOURCE_OPTION --cxxopt="-D_GLIBCXX_USE_CXX11_ABI=0" --cxxopt="-DTENSORFLOW_HSACO_USE_ROCM_LLVM" --config=v1 --config=opt --config=rocm //tensorflow/tools/pip_package:build_pip_package --verbose_failures &&
bazel-bin/tensorflow/tools/pip_package/build_pip_package $TF_PKG_LOC --rocm --nightly_flag &&
pip3 install --upgrade $TF_PKG_LOC/tf_nightly_rocm*.whl
else
# Remove any previous builds and build release
rm -f $TF_PKG_LOC/tensorflow*.whl
bazel build $RESOURCE_OPTION --config=opt --config=rocm //tensorflow/tools/pip_package:build_pip_package --verbose_failures &&
bazel build $RESOURCE_OPTION --cxxopt="-D_GLIBCXX_USE_CXX11_ABI=0" --cxxopt="-DTENSORFLOW_HSACO_USE_ROCM_LLVM" --config=v1 --config=opt --config=rocm //tensorflow/tools/pip_package:build_pip_package --verbose_failures &&
bazel-bin/tensorflow/tools/pip_package/build_pip_package $TF_PKG_LOC --rocm &&
pip3 install --upgrade $TF_PKG_LOC/tensorflow*.whl
fi
15 changes: 9 additions & 6 deletions tensorflow/compiler/xla/service/gpu/autotuner_util.cc
Original file line number Diff line number Diff line change
@@ -56,16 +56,19 @@ static auto& autotune_cache ABSL_GUARDED_BY(autotune_cache_mu) =

namespace {

void CSVLegend(std::ostream& os) {
void CSVLegend(std::ostream& os, bool full_string=false) {

os << kCsvComment << " m" << kCsvSep << "n" << kCsvSep << "k" << kCsvSep
<< "batch_count" << kCsvSep << "trans_a" << kCsvSep
<< "trans_b" << kCsvSep
<< "type_a" << kCsvSep << "type_b" << kCsvSep
<< "trans_b" << kCsvSep << "type_a" << kCsvSep << "type_b" << kCsvSep
<< "type_c" << kCsvSep << "lda" << kCsvSep << "ldb" << kCsvSep
<< "ldc" << kCsvSep << "stride_a" << kCsvSep
<< "stride_b" << kCsvSep << "stride_c" << kCsvSep
<< "alg_index" << std::endl;
<< "stride_b" << kCsvSep << "stride_c";
if (full_string) {
os << kCsvSep << "alpha_re" << kCsvSep << "alpha_im" << kCsvSep
<< "beta" << kCsvSep << "epilogue";
}
os << kCsvSep << "alg_index" << std::endl;
}

} // namespace
@@ -89,7 +92,7 @@ void CSVLegend(std::ostream& os) {
if (!s_dump_fs->is_open()) {
LOG(WARNING) << "Unable to open: " << dump_path << " for writing!";
}
CSVLegend(*s_dump_fs);
CSVLegend(*s_dump_fs, true);
}
*s_dump_fs << key.Get() << kCsvSep << it->second << std::endl;
}
4 changes: 2 additions & 2 deletions tensorflow/compiler/xla/service/gpu/gemm_algorithm_picker.cc
Original file line number Diff line number Diff line change
@@ -367,7 +367,7 @@ StatusOr<bool> RunOnInstruction(HloInstruction* gemm,

GemmAutotuner autotuner(config);
TF_ASSIGN_OR_RETURN(auto new_algorithm,
AutotunerUtil::Autotune(se::gpu::ToCSVString(gemm_config, false), config,
AutotunerUtil::Autotune(se::gpu::ToCSVString(gemm_config, true), config,
[&]() -> StatusOr<AutotunerUtil::CacheValue> {
TF_ASSIGN_OR_RETURN(auto algo, autotuner(gemm, gemm_config));
return algo.has_gemm() ? algo.gemm().algorithm() : se::blas::kDefaultAlgorithm;
@@ -410,7 +410,7 @@ StatusOr<AutotunerUtil::CacheValue> GemmAlgorithmPicker::RunStandalone(
GemmAutotuner autotuner(config_);
GemmConfig gemm_config{cfg};

return AutotunerUtil::Autotune(se::gpu::ToCSVString(gemm_config, false), config_,
return AutotunerUtil::Autotune(se::gpu::ToCSVString(gemm_config, true), config_,
[&]() -> StatusOr<AutotunerUtil::CacheValue> {
TF_ASSIGN_OR_RETURN(auto algo, autotuner(gemm_config, std::move(input_shapes),
output_shape, debug_options));
Original file line number Diff line number Diff line change
@@ -173,7 +173,9 @@ auto CublasLtMatmulThunk::GetCachedMatmulPlan(
return std::move(plan);
}
}
return InternalError("Wrong algorithm ID: %d", algorithm_id);
TF_RETURN_IF_ERROR(plan->SetAlgorithm(algorithms[0]));
LOG(WARNING) << "Wrong algorithm ID: " << algorithm_id << " use default instead.";
return std::move(plan);
};
return cache.GetOrCreate(canonical_hlo_, create);
}
1 change: 1 addition & 0 deletions tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD
Original file line number Diff line number Diff line change
@@ -65,6 +65,7 @@ cc_library(
"@llvm-project//llvm:Target",
] + if_rocm_is_configured([
"@local_config_rocm//rocm:rocm_headers",
"//tensorflow/tsl/platform:rocm_rocdl_path",
"@llvm-project//llvm:AMDGPUCodeGen",
]),
)
424 changes: 339 additions & 85 deletions tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc

Large diffs are not rendered by default.

4 changes: 1 addition & 3 deletions tensorflow/compiler/xla/service/gpu/target_constants.h
Original file line number Diff line number Diff line change
@@ -46,9 +46,7 @@ inline const char* TargetTriple() {
// The data layout of the emitted module.
inline const char* DataLayout() {
static constexpr char kDataLayout[] =
"e-p:64:64-p1:64:64-p2:64:64-p3:32:32-p4:32:32-p5:32:32"
"-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128"
"-v192:256-v256:256-v512:512-v1024:1024-v2048:2048-n32:64-A5";
"e-p:64:64-p1:64:64-p2:32:32-p3:32:32-p4:64:64-p5:32:32-p6:32:32-p7:160:256:256:32-p8:128:128-p9:192:256:256:32-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-v2048:2048-n32:64-S32-A5-G1-ni:7:8:9";
return kDataLayout;
}

3 changes: 1 addition & 2 deletions tensorflow/compiler/xla/stream_executor/gpu/gpu_blas_lt.cc
Original file line number Diff line number Diff line change
@@ -264,8 +264,7 @@ std::string ToCSVString(const GemmConfig& cfg, bool full_string) {

if (full_string) {
// NOTE: epilogue is required for MatmulPlan caching !
oss //<< kCsvSep << cfg.alpha << kCsvSep << cfg.beta
<< kCsvSep << (int64_t)cfg.epilogue;
oss << kCsvSep << cfg.alpha.real() << kCsvSep << cfg.alpha.imag() << kCsvSep << cfg.beta << kCsvSep << (int64_t)cfg.epilogue;
}

return oss.str();
158 changes: 158 additions & 0 deletions tensorflow/tools/ci_build/Dockerfile.cs8.rocm
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
# This Dockerfile provides a starting point for a ROCm installation of
# MIOpen and tensorflow.
FROM almalinux:8
# MAINTAINER is a deprecated instruction; record the maintainer as image metadata instead.
LABEL maintainer="Jeff Poznanovic <jeffrey.poznanovic@amd.com>"

ARG RPM_ROCM_REPO=https://repo.radeon.com/rocm/rhel8/.yum_6.3.0.1/main
ARG ROCM_PATH=/opt/rocm-6.3.0.1

# Environment for the build, written in the key=value form throughout
# (the space-separated "ENV key value" form is legacy syntax).
ENV ROCM_PATH=$ROCM_PATH
ENV DEBIAN_FRONTEND=noninteractive
ENV TF_NEED_ROCM=1
ENV GCC_HOST_COMPILER_PATH=/opt/rocm/bin/amdclang
ENV HOME=/root/
RUN dnf update -y && dnf install -y epel-release && dnf install -y elrepo-release && dnf config-manager --set-enabled powertools
# Set up the build_system repo
RUN echo -e "[build_system]\nname=build_system\nbaseurl=https://repo.almalinux.org/build_system/8/x86_64/\nenabled=1\ngpgcheck=0" >/etc/yum.repos.d/build_system.repo
RUN dnf group install -y "Development Tools"

# Register the ROCm and amdgpu package repositories. bash is invoked by its
# absolute path so these steps do not depend on the working directory being "/".
# NOTE(review): gpgcheck=0 turns off package signature verification for these repos.
RUN /bin/bash -c 'echo -e "[ROCm]\nname=ROCm\nbaseurl=$RPM_ROCM_REPO\nenabled=1\ngpgcheck=0" >>/etc/yum.repos.d/rocm.repo'
RUN /bin/bash -c 'echo -e "[amdgpu]\nname=amdgpu\nbaseurl=https://repo.radeon.com/amdgpu/.6.3.0.1/rhel/8.8/main/x86_64/\nenabled=1\ngpgcheck=0" >> /etc/yum.repos.d/amdgpu.repo'

RUN dnf clean all
RUN dnf update -y

# Install misc pkgs (each package listed once; every continuation backslash is
# preceded by a space so adjacent names can never be joined).
RUN dnf --enablerepo=extras,epel,elrepo,powertools,build_system install -y \
    epel-release \
    openssl-devel \
    libffi-devel \
    hdf5-devel \
    wget \
    make \
    patch \
    zlib-devel \
    bzip2 \
    bzip2-devel \
    readline \
    readline-devel \
    sqlite \
    sqlite-devel \
    tk-devel \
    xz-devel

# Install build tools and system libraries.
# Every continuation backslash is preceded by a space: "patchelf\" and "wget\"
# would otherwise be glued to the following package name whenever the next
# line carries no leading whitespace.
RUN dnf --enablerepo=extras,epel,elrepo,powertools,build_system install -y \
    bc \
    bridge-utils \
    cmake \
    cmake3 \
    devscripts \
    dkms \
    doxygen \
    dpkg \
    dpkg-dev \
    dpkg-perl \
    elfutils-libelf-devel \
    expect \
    file \
    gettext \
    gcc-c++ \
    git \
    libgcc \
    ncurses \
    ncurses-base \
    ncurses-libs \
    numactl-devel \
    numactl-libs \
    libssh \
    libunwind-devel \
    libunwind \
    llvm \
    llvm-libs \
    make \
    openssl \
    openssl-libs \
    openssh \
    openssh-clients \
    pciutils \
    pciutils-devel \
    pciutils-libs \
    java-11-openjdk-devel \
    patchelf \
    pkgconfig \
    npth \
    qemu-kvm \
    re2c \
    rpm \
    rpm-build \
    subversion \
    sudo \
    wget \
    kernel-devel-uname-r

# Install the ROCm stack: the amdgpu DRM library, the ROCm dev / ML SDK
# meta-packages, MIOpen, the roc*/hip* math and communication libraries with
# their -devel packages, the ROCm LLVM development package, and Boost headers.
RUN dnf --enablerepo=extras,build_system install -y \
libdrm-amdgpu \
rocm-dev \
rocm-ml-sdk \
miopen-hip \
miopen-hip-devel \
rocblas \
rocblas-devel \
rocsolver-devel \
rocrand-devel \
rocfft-devel \
hipfft-devel \
hipblas-devel \
rocprim-devel \
hipcub-devel \
rccl-devel \
hipsparse-devel \
hipsolver-devel \
hipblas-common-devel \
rocm-llvm-devel \
boost-devel

# Install Python 3.11 and make it the default python/python3.
RUN dnf --enablerepo=extras,epel,elrepo,powertools,build_system install -y \
    python3.11 \
    python3.11-devel \
    python3.11-pip \
    python3.11-wheel

RUN ln -sf /usr/bin/python3.11 /usr/bin/python3
RUN ln -sf /usr/bin/python3 /usr/bin/python
RUN ln -sf /usr/bin/python3.11 /etc/alternatives/python3

RUN python3 -m ensurepip
# Go through "python3 -m pip" so the packages are installed for the python3.11
# interpreter selected above, whichever "pip" executable happens to be on PATH.
RUN python3 -m pip install joblib numpy==1.24.0 requests packaging

ENV OPENCL_ROOT=$ROCM_PATH/opencl
ENV PATH="$ROCM_PATH/bin:${PATH}"
ENV PATH="$OPENCL_ROOT/bin:${PATH}"

# Workaround, explicitly add symbolic link to /opt/rocm
RUN touch ${ROCM_PATH}/.info/version
# -f/-n: replace an existing /opt/rocm link rather than failing on it or
# creating a nested link inside the directory it already points to.
RUN bash -c 'ln -sfn ${ROCM_PATH} /opt/rocm'

# Add target file to help determine which device(s) to build for
RUN bash -c 'echo -e "gfx942\ngfx90a\n" >> ${ROCM_PATH}/bin/target.lst'

# Setup environment variables, and add those environment variables at the end of ~/.bashrc
# NOTE(review): HCC_HOME and HIP_PATH are never defined in this Dockerfile, and
# PATH has already been set with ENV above, which takes precedence over an ARG
# of the same name. This ARG therefore looks like a no-op; nothing in this file
# writes to ~/.bashrc either. Confirm, then drop it or replace it with ENV.
ARG PATH=$HCC_HOME/bin:$HIP_PATH/bin:$PATH

COPY install/*.sh /install/

SHELL ["/bin/bash", "-c"]
RUN /install/install_bazel.sh
RUN /install/install_golang.sh

# Configure the build for our ROCm configuration.
ENV TF_NEED_ROCM=1

# This is a temporary workaround to fix Out-Of-Memory errors we are running into with XLA perf tests.
# By default, the HIP runtime "hides" 256MB from the TF runtime, but with recent changes (update to ROCm 2.3, dynamic loading of roc* libs, et al.)
# it seems that we need to raise the threshold slightly, to 320MB.
ENV HIP_HIDDEN_FREE_MEM=320

#We'll be using a custom CK build in this branch
# RUN bash -c 'mv ${ROCM_PATH}/include/ck ${ROCM_PATH}/include/ck-back'
13 changes: 13 additions & 0 deletions tensorflow/tools/hlo_benchmark/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# A script that estimates the achieved FLOPS of HLO modules
## Build
```
bazel build --cxxopt="-D_GLIBCXX_USE_CXX11_ABI=0" --config=v1 --config=opt --config=rocm tensorflow/compiler/xla/tools:run_hlo_module tensorflow/compiler/xla/tools:compute_cost --verbose_failures
```

## Usage
```
python tensorflow/tools/hlo_benchmark/hlo_estimate.py --hlo=tensorflow/compiler/xla/tests/*.gfx942_gpu_after_optimizations.txt --output=result.txt
```

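## Example Output
Each output line has the form `<hlo file> <throughput> GFLOPS/s <average time> s <size> MiB`; the sample below omits the trailing MiB column.
```
slow_xla_sample_v2/module_5629.cluster_4221__XlaCompiledKernel_true__XlaHasReferenceVars_true__XlaNumConstantArgs_0__XlaNumResourceArgs_0_.785.gfx942_gpu_after_optimizations.txt 4332.0634140346565 GFLOPS/s 0.0008137 s
```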
75 changes: 75 additions & 0 deletions tensorflow/tools/hlo_benchmark/hlo_estimate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
import subprocess
import glob
import re
import argparse

# Benchmark dumped HLO modules and estimate the throughput each one achieves.
#
# For every file matched by --hlo the script
#   1. replays the module with run_hlo_module and averages the execution times
#      it reports, ignoring the first --warmup measurements, and
#   2. asks compute_cost for the module's GFLOPS and MiB figures,
# then writes "<file> <GFLOPS/s> GFLOPS/s <avg time> s <MiB> MiB" to --output.
# Both tools are looked up under bazel-bin/ relative to the working directory.

RUN_HLO_MODULE = "bazel-bin/tensorflow/compiler/xla/tools/run_hlo_module"
COMPUTE_COST = "bazel-bin/tensorflow/compiler/xla/tools/compute_cost"

HLO_BENCH_RE = r"execution time for runner ROCM: (?P<TIME>[-+]?[0-9]*\.?[0-9]+([eE][-+]?[0-9]+)?)"
FLOPS_RE = (r"(?P<FLOPS>[-+]?[0-9]*\.?[0-9]+([eE][-+]?[0-9]+)?) GFLOPS. "
            r"(?P<BYTES>[-+]?[0-9]*\.?[0-9]+([eE][-+]?[0-9]+)?) MiB.")


def average_time(stderr_text, warmup):
    """Average the execution times reported in run_hlo_module's stderr.

    Args:
        stderr_text: decoded stderr of a run_hlo_module invocation.
        warmup: number of leading measurements to discard.

    Returns:
        The mean of the remaining measurements, or 0 when none are left.
    """
    times = []
    for line in stderr_text.split('\n'):
        found = re.search(HLO_BENCH_RE, line)
        if found:
            times.append(float(found.group('TIME')))
    times = times[warmup:]
    if not times:
        return 0
    return sum(times) / len(times)


def parse_cost(stdout_text):
    """Extract (GFLOPS figure, MiB figure) from compute_cost's stdout.

    Only the first line matching FLOPS_RE is used. Returns (0, 0) when no
    line matches.
    """
    for line in stdout_text.split('\n'):
        found = re.search(FLOPS_RE, line)
        if found:
            return float(found.group('FLOPS')), float(found.group('BYTES'))
    return 0, 0


def run_tool(argv):
    """Run a tool with captured output and return the CompletedProcess.

    The command is passed as an argument list (no shell), so file names that
    contain spaces or shell metacharacters reach the tool unchanged.
    """
    import sys

    try:
        result = subprocess.run(argv, capture_output=True)
    except FileNotFoundError:
        sys.exit(f"error: {argv[0]} not found; build it with bazel first")
    if result.returncode != 0:
        # Keep going: the output is still parsed, but make the failure visible
        # instead of silently recording zeros.
        print(f"warning: {argv[0]} exited with status {result.returncode}",
              file=sys.stderr)
    return result


def main():
    """Parse the command line, benchmark every matched module, write the report."""
    parser = argparse.ArgumentParser(
        description="Estimate the achieved GFLOPS/s of dumped HLO modules")
    parser.add_argument(
        "--hlo", type=str, required=True,
        help="Glob path to hlo modules")
    parser.add_argument(
        "--output", type=str, required=True,
        help="Output file path")
    parser.add_argument(
        "--warmup", type=int, default=10,
        help="Warmup iterations")
    parser.add_argument(
        "--iters", type=int, default=10,
        help="Timed iterations measured after the warmup")
    args = parser.parse_args()

    with open(args.output, 'w') as report:
        for hlo_file in glob.glob(args.hlo):
            bench = run_tool([
                RUN_HLO_MODULE,
                "--reference_platform=",
                "--xla_disable_all_hlo_passes=true",
                f"--iterations={args.warmup + args.iters}",
                "--platform=gpu",
                hlo_file,
            ])
            avg_time = average_time(
                bench.stderr.decode('utf-8', errors='replace'), args.warmup)

            cost = run_tool([COMPUTE_COST, "--format=hlo", f"--input={hlo_file}"])
            flops, mib = parse_cost(cost.stdout.decode('utf-8', errors='replace'))

            gflops_per_s = 0
            if avg_time > 0 and flops > 0:
                gflops_per_s = flops / avg_time

            report.write(f"{hlo_file} {gflops_per_s} GFLOPS/s {avg_time} s {mib} MiB\n")
            # Flush per module so partial results survive an interrupted run.
            report.flush()


if __name__ == "__main__":
    main()
4 changes: 2 additions & 2 deletions third_party/gpus/crosstool/BUILD.rocm.tpl
Original file line number Diff line number Diff line change
@@ -87,14 +87,14 @@ cc_toolchain_config(
"-fuse-ld=gold",
"-Wl,-no-as-needed",
"-Wl,-z,relro,-z,now",
"-pass-exit-codes",
# "-pass-exit-codes",
"-lstdc++",
"-lm",
],
link_libs = [],
opt_link_flags = [],
unfiltered_compile_flags = [
"-fno-canonical-system-headers",
# "-fno-canonical-system-headers",
"-Wno-builtin-macro-redefined",
"-D__DATE__=\"redacted\"",
"-D__TIMESTAMP__=\"redacted\"",
Original file line number Diff line number Diff line change
@@ -75,7 +75,9 @@ def GetHostCompilerOptions(argv):
parser.add_argument('-iquote', nargs='*', action='append')
parser.add_argument('--sysroot', nargs=1)
parser.add_argument('-g', nargs='*', action='append')
parser.add_argument('-fno-canonical-system-headers', action='store_true')
parser.add_argument('-no-canonical-prefixes', action='store_true')
parser.add_argument('-Wno-unused-variable', action='store_true')
parser.add_argument('-Wno-unused-but-set-variable', action='store_true')

args, _ = parser.parse_known_args(argv)

@@ -87,10 +89,16 @@ def GetHostCompilerOptions(argv):
opts += ' -iquote ' + ' -iquote '.join(sum(args.iquote, []))
if args.g:
opts += ' -g' + ' -g'.join(sum(args.g, []))
if args.fno_canonical_system_headers:
if args.no_canonical_prefixes:
opts += ' -no-canonical-prefixes'
if args.sysroot:
opts += ' --sysroot ' + args.sysroot[0]
if args.Wno_unused_variable:
opts += ' -Wno-unused-variable'

if args.Wno_unused_but_set_variable:
opts += ' -Wno-unused-but-set-variable'


return opts

@@ -282,7 +290,13 @@ def main():
if not flag.startswith(('--rocm_log'))]

# XXX: SE codes need to be built with gcc, but need this macro defined
cpu_compiler_flags.append("-D__HIP_PLATFORM_HCC__")
cpu_compiler_flags.append("-D__HIP_PLATFORM_AMD__")
cpu_compiler_flags.append('-L' + HIP_RUNTIME_PATH)
cpu_compiler_flags.append('-Wl,-rpath=' + HIP_RUNTIME_PATH)
cpu_compiler_flags.append('-l' + HIP_RUNTIME_LIBRARY)
cpu_compiler_flags.append("-lrt")
cpu_compiler_flags.append("-Wno-unused-command-line-argument")
cpu_compiler_flags.append("-Wno-gnu-offsetof-extensions")
if VERBOSE: print(' '.join([CPU_COMPILER] + cpu_compiler_flags))
return subprocess.call([CPU_COMPILER] + cpu_compiler_flags)

Original file line number Diff line number Diff line change
@@ -1046,7 +1046,7 @@ def _impl(ctx):
flag_group(
flags = [
"-no-canonical-prefixes",
"-fno-canonical-system-headers",
#"-fno-canonical-system-headers",
]
),
],
4 changes: 3 additions & 1 deletion third_party/gpus/rocm_configure.bzl
Original file line number Diff line number Diff line change
@@ -715,12 +715,14 @@ def _create_local_rocm_repository(repository_ctx):
# .d file - given that includes that are prefixed with "../" multiple
# time quickly grow longer than the root of the tree, this can lead to
# bazel's header check failing.
rocm_defines["%{extra_no_canonical_prefixes_flags}"] = "\"-fno-canonical-system-headers\""
rocm_defines["%{extra_no_canonical_prefixes_flags}"] = ""

rocm_defines["%{unfiltered_compile_flags}"] = to_list_of_strings([
"-DTENSORFLOW_USE_ROCM=1",
"-D__HIP_PLATFORM_AMD__",
"-DEIGEN_USE_HIP",
"-Wno-unused-but-set-variable",
"-Wno-c++11-narrowing",
])

rocm_defines["%{host_compiler_path}"] = "clang/bin/crosstool_wrapper_driver_is_not_gcc"