diff --git a/c_src/zstd_nif.c b/c_src/zstd_nif.c index 0264929..8955581 100644 --- a/c_src/zstd_nif.c +++ b/c_src/zstd_nif.c @@ -1,10 +1,17 @@ #include "erl_nif.h" - #include +#include +#include #include +#define MAX_BYTES_TO_NIF 20000 + ErlNifTSDKey zstdDecompressContextKey; ErlNifTSDKey zstdCompressContextKey; +ErlNifTSDKey zstdCompressToFileContextKey; + +static ERL_NIF_TERM do_compress_to_file(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]); + static ERL_NIF_TERM zstd_nif_compress(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { ErlNifBinary bin, ret_bin; @@ -62,15 +69,89 @@ static ERL_NIF_TERM zstd_nif_decompress(ErlNifEnv* env, int argc, const ERL_NIF_ return out; } + + +static int save_file(const char* fileName, const void* buff, size_t buffSize) +{ + FILE* const oFile = fopen(fileName, "a"); + if (!oFile) { + return 0; + } + size_t const wSize = fwrite(buff, 1, buffSize, oFile); + if (wSize != (size_t)buffSize) { + return 0; + } + if (fclose(oFile)) { + return 0; + } + return 1; +} +static ERL_NIF_TERM zstd_nif_compress_to_file(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { + ErlNifBinary bin; + if(!enif_inspect_iolist_as_binary(env, argv[0], &bin)) return enif_make_badarg(env); + if (bin.size > MAX_BYTES_TO_NIF) { + return enif_schedule_nif(env, "do_compress_to_file", ERL_NIF_DIRTY_JOB_CPU_BOUND, do_compress_to_file, argc, argv); + } + + return do_compress_to_file(env, argc, argv); +} + +static ERL_NIF_TERM do_compress_to_file(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { + ErlNifBinary bin, ret_bin; + size_t buff_size, compressed_size; + unsigned int compression_level, path_len; + + ZSTD_CCtx* ctx = (ZSTD_CCtx*)enif_tsd_get(zstdCompressToFileContextKey); + if (!ctx) { + ctx = ZSTD_createCCtx(); + enif_tsd_set(zstdCompressToFileContextKey, ctx); + } + + enif_get_list_length(env, argv[1], &path_len); + char path[path_len + 1]; + + if(!enif_inspect_iolist_as_binary(env, argv[0], &bin) + || !enif_get_string(env, argv[1], path, (path_len + 1), ERL_NIF_LATIN1) + || !enif_get_uint(env, argv[2], &compression_level) + || compression_level > ZSTD_maxCLevel()) + return enif_make_badarg(env); + + buff_size = ZSTD_compressBound(bin.size); + + if(!enif_alloc_binary(buff_size, &ret_bin)) + return enif_make_atom(env, "error"); + + compressed_size = ZSTD_compressCCtx(ctx, ret_bin.data, buff_size, bin.data, bin.size, compression_level); + if(ZSTD_isError(compressed_size)) { + enif_release_binary(&ret_bin); + return enif_make_atom(env, "error"); + } + + if(!enif_realloc_binary(&ret_bin, compressed_size)) { + enif_release_binary(&ret_bin); + return enif_make_atom(env, "error"); + } + + if (!save_file(path, ret_bin.data, compressed_size)) { + enif_release_binary(&ret_bin); + return enif_make_atom(env, "error"); + } + + enif_release_binary(&ret_bin); + return enif_make_atom(env, "ok"); +} + static ErlNifFunc nif_funcs[] = { {"compress", 2, zstd_nif_compress}, - {"decompress", 1, zstd_nif_decompress} + {"decompress", 1, zstd_nif_decompress}, + {"compress_to_file", 3, zstd_nif_compress_to_file}, }; static int load(ErlNifEnv* env, void** priv_data, ERL_NIF_TERM load_info) { enif_tsd_key_create("zstd_decompress_context_key", &zstdDecompressContextKey); enif_tsd_key_create("zstd_compress_context_key", &zstdCompressContextKey); + enif_tsd_key_create("zstd_compress_to_file_context_key", &zstdCompressToFileContextKey); return 0; } diff --git a/src/zstd.erl b/src/zstd.erl index 5da34b4..efbdd80 100644 --- a/src/zstd.erl +++ b/src/zstd.erl @@ -2,6 +2,7 @@ -export([compress/1, compress/2]). -export([decompress/1]). +-export([compress_to_file/2, compress_to_file/3]). -on_load init/0. @@ -21,6 +22,17 @@ compress(_, _) -> decompress(_) -> erlang:nif_error(?LINE). +-spec compress_to_file(Uncompressed :: iolist(), Path :: string()) -> ok | error. +compress_to_file(IoList, Path) -> + compress_to_file(IoList, Path, 1). + +-spec compress_to_file(Uncompressed :: iolist(), + Path :: string(), + CompressionLevel :: 0..22) -> + ok | error. +compress_to_file(_, _, _) -> + erlang:nif_error(?LINE). + init() -> SoName = case code:priv_dir(?APPNAME) of diff --git a/test/zstd_tests.erl b/test/zstd_tests.erl index 4927a04..1a69011 100644 --- a/test/zstd_tests.erl +++ b/test/zstd_tests.erl @@ -3,7 +3,31 @@ -include_lib("eunit/include/eunit.hrl"). zstd_test() -> - Data = <<"Hello, World!">>, + Data = <<"Hello, World!\n">>, ?assertEqual(Data, zstd:decompress( zstd:compress(Data))). + +compress_to_file_test() -> + Path = "/tmp/zstd_test.zst", + Data = [<<"Hello">>, <<" there!">>], + compress_to_file_and_check(Path, Data). + +compress_to_file_using_dirty_scheduler_test() -> + Path = "/tmp/zstd_dirty_scheduler_test.zst", + Data = + [base64:encode( + crypto:strong_rand_bytes(64000))], + compress_to_file_and_check(Path, Data). + +compress_to_file_and_check(Path, Data) -> + case filelib:is_regular(Path) of + true -> + file:delete(Path); + _ -> + ok + end, + ?assertEqual(ok, zstd:compress_to_file(Data, Path)), + {ok, ToDecompress} = file:read_file(Path), + Decompressed = zstd:decompress(ToDecompress), + ?assertEqual(Decompressed, erlang:iolist_to_binary(Data)).