add half gemm support #68

Draft: wants to merge 25 commits into base: master
3998ebd hgemm: add half definition and interfaces (cayrols, Sep 22, 2023)
60afc47 hgemm: add hgemm tester; conversion routine from and to half; Nvidia … (cayrols, Sep 22, 2023)
f51f0c2 utils: move utils into test/cuda; Add missing NVCC flag; Clean utils … (cayrols, Sep 25, 2023)
98e452b utils: move utils into test/cuda; Add missing NVCC flag; Clean utils … (cayrols, Sep 25, 2023)
64e6697 half: Add hipify from the cuda utils; update the compilation chain to… (cayrols, Sep 27, 2023)
c7061d0 Replace the definition of half from __half to _float16; rename blas::… (cayrols, Oct 4, 2023)
68cb36d Rename test/utils.cuh into test/utils.hh (cayrols, Oct 4, 2023)
1c99c49 Add hip files that got generated from cuda files. (cayrols, Oct 4, 2023)
7edd364 hgemm: Add hip support for hgemm (cayrols, Oct 4, 2023)
ffac1ae hgemm: fix compilation after cleaning. (cayrols, Oct 4, 2023)
bd9511d test: gemm change the bound by removing sqrt. (cayrols, Oct 18, 2023)
44a9e73 test: fix scalar type used to get the flop count in half gemm. (cayrols, Oct 18, 2023)
82752bf hgemm: Use a class float16 instead of an alias. (cayrols, Oct 18, 2023)
b3c836c config: fix gpu_backend name issue. (cayrols, Oct 18, 2023)
6a181e8 hgemm: fix compilation issue. (cayrols, Oct 18, 2023)
3ced0ce TMP: Add explicit compilation flag for reproducer purpose. (cayrols, Oct 18, 2023)
d1fd85c hgemm: Search _Float16 support from compiler; If so, the macro BLAS_U… (cayrols, Nov 29, 2023)
8b8320b hgemm: add CPU support through MKL. (cayrols, Nov 28, 2023)
30cd556 hgemm: add CPU test with cblas wrapper; add cast_onto_device util tha… (cayrols, Nov 28, 2023)
d686451 float16: update configure search and macro definition. (cayrols, Dec 1, 2023)
5a30ad8 hgemm: Fake casting in cublas_wrapper through pointer casting. (cayrols, Dec 1, 2023)
f4fed1a test: in hgemm, add cpu casting support, remove cast_onto_device rout… (cayrols, Dec 1, 2023)
d0d2867 float16: add casting routines from/to fp16 to/from fp32. (cayrols, Dec 5, 2023)
70d893b float16: add missing config file. (cayrols, Dec 6, 2023)
8678c82 hgemm: enable CPU hgemm only when MKL is provided. (cayrols, Dec 14, 2023)
114 changes: 114 additions & 0 deletions GNUmakefile
@@ -37,6 +37,14 @@ make.inc:
RANLIB ?= ranlib
prefix ?= /opt/slate

NVCC ?= nvcc
HIPCC ?= hipcc
hipify ?= hipify-perl
md5sum ?= tools/md5sum.pl

NVCCFLAGS += -O3 -std=c++11 --compiler-options '-Wall -Wno-unused-function'
HIPCCFLAGS += -std=c++11 -DTCE_HIP -fno-gpu-rdc

abs_prefix := ${abspath ${prefix}}

# Default LD=ld won't work; use CXX. Can override in make.inc or environment.
@@ -52,11 +60,25 @@ ifneq ($(findstring darwin, $(ostype)),)
macos = 1
endif

#-------------------------------------------------------------------------------
# Detect which gpu_backend used
cuda = 0
hip = 0
sycl = 0

ifeq ($(gpu_backend),cuda)
cuda = 1
else ifeq ($(gpu_backend),hip)
hip = 1
endif

#-------------------------------------------------------------------------------
# if shared
ifneq ($(static),1)
CXXFLAGS += -fPIC
LDFLAGS += -fPIC
NVCCFLAGS += --compiler-options '-fPIC'
HIPCCFLAGS += -fPIC
lib_ext = so
else
lib_ext = a
@@ -77,7 +99,19 @@ lib_src = $(wildcard src/*.cc)
lib_obj = $(addsuffix .o, $(basename $(lib_src)))
dep += $(addsuffix .d, $(basename $(lib_src)))

cuda_src = $(wildcard test/cuda/*.cu)
hip_src = $(patsubst test/cuda/%.cu,test/hip/%.hip.cc,$(cuda_src))

tester_src = $(wildcard test/*.cc)

ifeq ($(cuda),1)
tester_src += $(cuda_src)
endif

ifeq ($(hip),1)
tester_src += $(hip_src)
endif

tester_obj = $(addsuffix .o, $(basename $(tester_src)))
dep += $(addsuffix .d, $(basename $(tester_src)))

@@ -123,6 +157,8 @@ src/version.o: .id
#-------------------------------------------------------------------------------
# BLAS++ specific flags and libraries
CXXFLAGS += -I./include
NVCCFLAGS += -I./include
HIPCCFLAGS += -I./include

# additional flags and libraries for testers
$(tester_obj): CXXFLAGS += -I$(testsweeper_dir)
@@ -158,6 +194,59 @@ uninstall:
$(RM) $(DESTDIR)$(abs_prefix)/lib$(LIB_SUFFIX)/libblaspp.*
$(RM) $(DESTDIR)$(abs_prefix)/lib$(LIB_SUFFIX)/pkgconfig/blaspp.pc

#-------------------------------------------------------------------------------
# HIP sources converted from CUDA sources.

# if_md5_outdated applies the given build rule ($1) only if the md5 sums
# of the target's dependency ($<) doesn't match that stored in the
# target's dep file ($@.dep). If the target ($@) is already up-to-date
# based on md5 sums, its timestamp is updated so make will recognize it
# as up-to-date. Otherwise, the target is built and its dep file
# updated. Instead of depending on the src file, the target depends on
# the md5 file of the src file. This can be adapted for multiple dependencies.
# Example usage:
#
# %: %.c.md5
# ${call if_md5_outdated,\
# gcc -o $@ ${basename $<}}
#
define if_md5_outdated
	if [ -e $@ ] && diff $< $@.dep > /dev/null 2>&1; then \
		echo " make: '$@' is up-to-date based on md5sum."; \
		echo " touch $@"; \
		touch $@; \
	else \
		echo " make: '$@' is out-of-date based on md5sum."; \
		echo " ${strip $1}"; \
		$1; \
		cp $< $@.dep; \
	fi
endef

# From GNU manual: Commas ... cannot appear in an argument as written.
# The[y] can be put into the argument value by variable substitution.
comma := ,

# Convert CUDA => HIP code.
# Explicitly mention ${hip_src}, ${hip_hdr}, ${md5_files}
# to prevent them from being intermediate files,
# so they are _always_ generated and never removed.
# Perl updates includes and removes excess spaces that fail style hook.
${hip_src}: test/hip/%.hip.cc: test/cuda/%.cu.md5 | test/hip
@${call if_md5_outdated, \
${hipify} ${basename $<} > $@; \
perl -pi -e 's/\.cuh/.hip.hh/g; s/ +(${comma}|;|$$)/$$1/g;' $@}

hipify: ${hip_src}

md5_files := ${addsuffix .md5, ${cuda_src}}

${md5_files}: %.md5: %
${md5sum} $< > $@

test/hip:
mkdir -p $@

#-------------------------------------------------------------------------------
# if re-configured, recompile everything
$(lib_obj) $(tester_obj): make.inc
@@ -286,9 +375,16 @@ hooks: ${hooks}
cp $< $@ ; \
fi

# .hip.cc rule before .cc rule.
%.hip.o: %.hip.cc
$(HIPCC) $(HIPCCFLAGS) -c $< -o $@

%.o: %.cc
$(CXX) $(CXXFLAGS) -c $< -o $@

%.o: %.cu
$(NVCC) $(NVCCFLAGS) -c $< -o $@

# preprocess source
%.i: %.cc
$(CXX) $(CXXFLAGS) -I$(testsweeper_dir) -E $< -o $@
@@ -333,6 +429,24 @@ echo:
@echo
@echo "dep = $(dep)"
@echo
@echo "---------- CUDA options"
@echo "cuda = '$(cuda)'"
@echo "NVCC = $(NVCC)"
@echo "NVCC_which = $(NVCC_which)"
@echo "CUDA_PATH = $(CUDA_PATH)"
@echo "NVCCFLAGS = $(NVCCFLAGS)"
@echo
@echo "---------- HIP options"
@echo "hip = '$(hip)'"
@echo "HIPCC = $(HIPCC)"
@echo "HIPCC_which = $(HIPCC_which)"
@echo "ROCM_PATH = $(ROCM_PATH)"
@echo "HIPCCFLAGS = $(HIPCCFLAGS)"
@echo "hipify = ${hipify}"
@echo "cuda_src = ${cuda_src}"
@echo "hip_src = ${hip_src}"
@echo "md5_files = $(md5_files)"
@echo
@echo "testsweeper_dir = $(testsweeper_dir)"
@echo "testsweeper_src = $(testsweeper_src)"
@echo "testsweeper = $(testsweeper)"
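
The `if_md5_outdated` macro in this GNUmakefile rebuilds a target only when the checksum of its source differs from the one recorded at the last build; otherwise it just touches the target. A minimal Python sketch of that decision is given here for illustration. The names `rebuild_if_md5_outdated` and `copy_build`, and the choice to store the digest as plain hex in `<target>.dep`, are assumptions made for this example and are not part of the PR.

```python
import hashlib
import os
import tempfile


def md5_of(path):
    """Return the hex MD5 digest of a file's contents."""
    with open(path, "rb") as f:
        return hashlib.md5(f.read()).hexdigest()


def rebuild_if_md5_outdated(src, target, build):
    """Run build(src, target) only if src's digest changed since the last build.

    Hypothetical helper: the digest from the last build lives in target + '.dep'.
    Returns True if the build ran, False if the target was only touched.
    """
    dep = target + ".dep"
    digest = md5_of(src)
    if os.path.exists(target) and os.path.exists(dep):
        with open(dep) as f:
            if f.read().strip() == digest:
                os.utime(target, None)  # like `touch`: refresh the timestamp
                return False
    build(src, target)
    with open(dep, "w") as f:
        f.write(digest + "\n")
    return True


def demo():
    """Build once, retry with an unchanged source, then retry after an edit."""
    def copy_build(src, target):  # stand-in for the hipify step
        with open(src) as fin, open(target, "w") as fout:
            fout.write(fin.read())

    with tempfile.TemporaryDirectory() as d:
        src = os.path.join(d, "kernel.cu")
        target = os.path.join(d, "kernel.hip.cc")
        with open(src, "w") as f:
            f.write("// v1\n")
        first = rebuild_if_md5_outdated(src, target, copy_build)
        second = rebuild_if_md5_outdated(src, target, copy_build)
        with open(src, "w") as f:
            f.write("// v2\n")
        third = rebuild_if_md5_outdated(src, target, copy_build)
    return first, second, third
```

Comparing digests rather than timestamps means a regenerated but identical `.cu` file does not trigger a new hipify pass.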
23 changes: 22 additions & 1 deletion config/config.py
@@ -77,7 +77,7 @@ def define( var, value=None ):

# ------------------------------------------------------------------------------
# variables to replace instead of appending/prepending
- replace_vars = ['CC', 'CXX', 'NVCC', 'FC', 'AR', 'RANLIB', 'prefix']
+ replace_vars = ['CC', 'CXX', 'NVCC', 'FC', 'AR', 'RANLIB', 'prefix', 'gpu_backend']

# ------------------------------------------------------------------------------
# map file extensions to languages
@@ -615,6 +615,24 @@ def openmp( flags=['-fopenmp', '-qopenmp', '-openmp', '-omp', ''] ):
# end
# end

#-------------------------------------------------------------------------------
def float16( ):
    '''
    Tests for _Float16 support from the compiler.
    '''
    print_header( '_Float16 support' )
    src = 'config/return_float16.cc'
    cxxflags = define('HAVE_ISO_FLOAT16')
    print_test( cxxflags )
    env = {'CXXFLAGS': cxxflags}
    (rc, out, err) = compile_run( src, env )
    print_result( "_Float16", rc )
    if (rc == 0):
        environ.merge( env )
    else:
        print_msg( font.red( 'skipping _Float16 search' ) )
# end

#-------------------------------------------------------------------------------
def cublas_library():
'''
@@ -752,6 +770,7 @@ def gpu_blas():
try:
cublas_library()
gpu_blas_found = True
environ.merge( {'gpu_backend' : 'cuda' } )
except Error as ex:
if (gpu_backend == 'cuda'):
raise ex # fatal
@@ -763,6 +782,7 @@ def gpu_blas():
try:
rocblas_library()
gpu_blas_found = True
environ.merge( {'gpu_backend' : 'hip' } )
except Error as ex:
if (gpu_backend in ('hip', 'rocm')):
raise ex # fatal
@@ -773,6 +793,7 @@ def gpu_blas():
if (not gpu_blas_found and test_sycl):
try:
sycl_onemkl_library()
environ.merge( {'gpu_backend' : 'sycl' } )
gpu_blas_found = True
except Error as ex:
if (gpu_backend == 'sycl'):
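
The three `gpu_blas()` hunks share one pattern: probe a backend library, record the backend that was found in `gpu_backend`, and treat a failed probe as fatal only when the user explicitly asked for that backend. A simplified Python sketch of that control flow follows. The function `detect_gpu_backend`, the `'auto'` value, and the `probes` list are hypothetical stand-ins for illustration; the real script uses its own flags and `environ.merge`.

```python
class Error(Exception):
    """Stand-in for the configure script's Error exception."""


def detect_gpu_backend(requested, probes):
    """Return the name of the first backend whose probe succeeds, or None.

    requested: 'auto' or a backend name (hypothetical simplification).
    probes: ordered list of (name, callable); a callable raises Error on failure.
    """
    for name, probe in probes:
        if requested not in ("auto", name):
            continue  # the user asked for a different backend
        try:
            probe()
            return name  # corresponds to environ.merge({'gpu_backend': name})
        except Error:
            if requested == name:
                raise  # an explicitly requested backend is fatal when missing
    return None


def missing():
    """Probe that fails, as if the library were not installed."""
    raise Error("library not found")


def found():
    """Probe that succeeds."""
    pass


print(detect_gpu_backend("auto", [("cuda", missing), ("hip", found)]))
```

Recording the detected name lets the GNUmakefile later test `$(gpu_backend)` against `cuda` or `hip` without repeating the detection.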
19 changes: 19 additions & 0 deletions config/return_float16.cc
@@ -0,0 +1,19 @@
// Copyright (c) 2017-2022, University of Tennessee. All rights reserved.
// SPDX-License-Identifier: BSD-3-Clause
// This program is free software: you can redistribute it and/or modify it under
// the terms of the BSD 3-Clause license. See the accompanying LICENSE file.

#include <stdio.h>

#include "config.h"

//------------------------------------------------------------------------------
int main()
{
    _Float16 a = 0.1;
    _Float16 b = 0.2;
    _Float16 c = a + b;

    printf( "%f + %f = %f -- expected 0.3\n", (float)a, (float)b, (float)c );
    return 0;
}
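
The probe prints `0.1 + 0.2` evaluated in `_Float16`. Because binary16 carries an 11-bit significand, the printed values are only close to the decimal ones. A rough Python illustration using the `e` (binary16) format of the standard `struct` module follows. It models the rounding only; treating the addition as a single rounding of an exact sum is an assumption here and does not describe how any particular compiler evaluates `_Float16` arithmetic.

```python
import struct


def to_fp16(x):
    """Round a Python float to the nearest IEEE 754 binary16 value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


a = to_fp16(0.1)
b = to_fp16(0.2)
c = to_fp16(a + b)  # a + b is exact in double; round once to binary16

print(f"{a!r} + {b!r} = {c!r} -- expected about 0.3")
print("absolute error vs 0.3:", abs(c - 0.3))
```

The configure test only checks that the program compiles and runs, so this loss of precision does not affect whether `_Float16` support is detected.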
2 changes: 2 additions & 0 deletions configure.py
@@ -57,6 +57,8 @@ def main():
#config.prog_cxx_flag( '-Wconversion' )
#config.prog_cxx_flag( '-Werror' )

config.float16()

config.openmp()

config.lapack.blas()
12 changes: 12 additions & 0 deletions include/blas/device_blas.hh
@@ -207,6 +207,18 @@ void swap(
// Level 3 BLAS

//------------------------------------------------------------------------------
void gemm(
    blas::Layout layout,
    blas::Op transA,
    blas::Op transB,
    int64_t m, int64_t n, int64_t k,
    float16 alpha,
    float16 const* A, int64_t lda,
    float16 const* B, int64_t ldb,
    float16 beta,
    float16* C, int64_t ldc,
    blas::Queue& queue );

void gemm(
blas::Layout layout,
blas::Op transA,
14 changes: 14 additions & 0 deletions include/blas/fortran.h
@@ -903,6 +903,20 @@ void BLAS_ztrsv_base(
// =============================================================================
// Level 3 BLAS - Fortran prototypes

#if defined(BLAS_HAVE_MKL)
#include <mkl_types.h>
// -----------------------------------------------------------------------------
#define BLAS_hgemm BLAS_FORTRAN_NAME( hgemm, HGEMM )
void BLAS_hgemm(
    char const *transA, char const *transB,
    blas_int const *m, blas_int const *n, blas_int const *k,
    MKL_F16 const *alpha,
    MKL_F16 const *A, blas_int const *lda,
    MKL_F16 const *B, blas_int const *ldb,
    MKL_F16 const *beta,
    MKL_F16 *C, blas_int const *ldc );
#endif

// -----------------------------------------------------------------------------
#define BLAS_sgemm_base BLAS_FORTRAN_NAME( sgemm, SGEMM )
void BLAS_sgemm_base(