Skip to content

Commit 92d5eb8

Browse files
committed
Merge remote-tracking branch 'ngxson/xsn/emscripten_webgpu' into wasm
2 parents 9726640 + bf9d14c commit 92d5eb8

File tree

8 files changed

+213
-37
lines changed

8 files changed

+213
-37
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,3 +152,5 @@ poetry.toml
152152
# IDE
153153
*.code-workspace
154154
.windsurf/
155+
# emscripten
156+
a.out.*

CMakeLists.txt

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,11 @@ option(LLAMA_USE_SYSTEM_GGML "Use system libggml" OFF)
3636
if (EMSCRIPTEN)
3737
set(BUILD_SHARED_LIBS_DEFAULT OFF)
3838

39-
option(LLAMA_WASM_SINGLE_FILE "llama: embed WASM inside the generated llama.js" ON)
39+
option(LLAMA_WASM_SINGLE_FILE "llama: embed WASM inside the generated llama.js" OFF)
40+
option(LLAMA_BUILD_HTML "llama: build HTML file" ON)
41+
if (LLAMA_BUILD_HTML)
42+
set(CMAKE_EXECUTABLE_SUFFIX ".html")
43+
endif()
4044
else()
4145
if (MINGW)
4246
set(BUILD_SHARED_LIBS_DEFAULT OFF)

common/common.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -889,6 +889,8 @@ std::string fs_get_cache_directory() {
889889
cache_directory = std::getenv("HOME") + std::string("/Library/Caches/");
890890
#elif defined(_WIN32)
891891
cache_directory = std::getenv("LOCALAPPDATA");
892+
#elif defined(__EMSCRIPTEN__)
893+
GGML_ABORT("not implemented on this platform");
892894
#else
893895
# error Unknown architecture
894896
#endif

ggml/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -224,7 +224,7 @@ option(GGML_WEBGPU "ggml: use WebGPU"
224224
option(GGML_WEBGPU_DEBUG "ggml: enable WebGPU debug output" OFF)
225225
option(GGML_WEBGPU_CPU_PROFILE "ggml: enable WebGPU profiling (CPU)" OFF)
226226
option(GGML_WEBGPU_GPU_PROFILE "ggml: enable WebGPU profiling (GPU)" OFF)
227-
227+
option(GGML_WEBGPU_JSPI "ggml: use JSPI for WebGPU" ON)
228228
option(GGML_ZDNN "ggml: use zDNN" OFF)
229229
option(GGML_METAL "ggml: use Metal" ${GGML_METAL_DEFAULT})
230230
option(GGML_METAL_NDEBUG "ggml: disable Metal debugging" OFF)

ggml/src/ggml-webgpu/CMakeLists.txt

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,15 +39,35 @@ add_dependencies(ggml-webgpu generate_shaders)
3939
if(EMSCRIPTEN)
4040
set(EMDAWNWEBGPU_DIR "" CACHE PATH "Path to emdawnwebgpu_pkg")
4141

42-
target_compile_options(ggml-webgpu PRIVATE "--use-port=${EMDAWNWEBGPU_DIR}/emdawnwebgpu.port.py")
43-
target_link_options(ggml-webgpu PRIVATE "--use-port=${EMDAWNWEBGPU_DIR}/emdawnwebgpu.port.py")
42+
if(NOT EMDAWNWEBGPU_DIR)
43+
# default built-in port
44+
target_compile_options(ggml-webgpu PRIVATE "--use-port=emdawnwebgpu")
45+
target_link_options(ggml-webgpu INTERFACE "--use-port=emdawnwebgpu")
46+
else()
47+
# custom port
48+
target_compile_options(ggml-webgpu PRIVATE "--use-port=${EMDAWNWEBGPU_DIR}/emdawnwebgpu.port.py")
49+
target_link_options(ggml-webgpu INTERFACE "--use-port=${EMDAWNWEBGPU_DIR}/emdawnwebgpu.port.py")
50+
endif()
51+
52+
# Select the async mechanism for the WebGPU backend and the exception-handling flags that go
# with it. The same exception flag is passed at compile time (PRIVATE) and propagated to the
# final link of every consumer (INTERFACE), so both steps agree.
if (GGML_WEBGPU_JSPI)
    target_compile_options(ggml-webgpu PRIVATE "-fwasm-exceptions")
    target_link_options(ggml-webgpu INTERFACE "-sJSPI" "-fwasm-exceptions")
else()
    target_compile_options(ggml-webgpu PRIVATE "-fexceptions")
    # Previously "-exceptions", which is not an exception-handling flag; it has to match
    # the "-fexceptions" used for compilation above.
    target_link_options(ggml-webgpu INTERFACE "-sASYNCIFY" "-fexceptions")
endif()
59+
60+
set(DawnWebGPU_TARGET webgpu_cpp)
4461
else()
4562
find_package(Dawn REQUIRED)
4663
set(DawnWebGPU_TARGET dawn::webgpu_dawn)
4764
endif()
4865

4966
if (GGML_WEBGPU_DEBUG)
5067
target_compile_definitions(ggml-webgpu PRIVATE GGML_WEBGPU_DEBUG=1)
68+
if(EMSCRIPTEN)
69+
target_link_options(ggml-webgpu INTERFACE "-sASSERTIONS=2")
70+
endif()
5171
endif()
5272

5373
if (GGML_WEBGPU_CPU_PROFILE)

ggml/src/ggml-webgpu/ggml-webgpu.cpp

Lines changed: 45 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,10 @@
99
#include "ggml-impl.h"
1010
#include "ggml-wgsl-shaders.hpp"
1111

12+
#ifdef __EMSCRIPTEN__
13+
# include <emscripten/emscripten.h>
14+
#endif
15+
1216
#include <webgpu/webgpu_cpp.h>
1317

1418
#include <atomic>
@@ -2382,13 +2386,17 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t
23822386

23832387
webgpu_context ctx = reg_ctx->webgpu_ctx;
23842388

2389+
wgpu::RequestAdapterOptions options = {};
2390+
2391+
#ifndef __EMSCRIPTEN__
23852392
// TODO: track need for these toggles: https://issues.chromium.org/issues/42251215
23862393
const char * const adapterEnabledToggles[] = { "vulkan_enable_f16_on_nvidia", "use_vulkan_memory_model" };
23872394
wgpu::DawnTogglesDescriptor adapterTogglesDesc;
23882395
adapterTogglesDesc.enabledToggles = adapterEnabledToggles;
23892396
adapterTogglesDesc.enabledToggleCount = 2;
2390-
wgpu::RequestAdapterOptions options = {};
23912397
options.nextInChain = &adapterTogglesDesc;
2398+
#endif
2399+
23922400
ctx->instance.WaitAny(ctx->instance.RequestAdapter(
23932401
&options, wgpu::CallbackMode::AllowSpontaneous,
23942402
[&ctx](wgpu::RequestAdapterStatus status, wgpu::Adapter adapter, const char * message) {
@@ -2438,7 +2446,11 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t
24382446

24392447
// Initialize device
24402448
std::vector<wgpu::FeatureName> required_features = { wgpu::FeatureName::ShaderF16,
2441-
wgpu::FeatureName::ImplicitDeviceSynchronization };
2449+
#ifndef __EMSCRIPTEN__
2450+
wgpu::FeatureName::ImplicitDeviceSynchronization
2451+
#endif
2452+
};
2453+
24422454
if (ctx->supports_subgroup_matrix) {
24432455
required_features.push_back(wgpu::FeatureName::Subgroups);
24442456
required_features.push_back(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix);
@@ -2448,19 +2460,6 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t
24482460
required_features.push_back(wgpu::FeatureName::TimestampQuery);
24492461
#endif
24502462

2451-
// Enable Dawn-specific toggles to increase native performance
2452-
// TODO: Don't enable for WASM builds, they won't have an effect anyways
2453-
// TODO: Maybe WebGPU needs a "fast" mode where you can request compilers skip adding checks like these,
2454-
// only for native performance?
2455-
const char * const deviceEnabledToggles[] = { "skip_validation", "disable_robustness", "disable_workgroup_init",
2456-
"disable_polyfills_on_integer_div_and_mod" };
2457-
const char * const deviceDisabledToggles[] = { "timestamp_quantization" };
2458-
wgpu::DawnTogglesDescriptor deviceTogglesDesc;
2459-
deviceTogglesDesc.enabledToggles = deviceEnabledToggles;
2460-
deviceTogglesDesc.enabledToggleCount = 4;
2461-
deviceTogglesDesc.disabledToggles = deviceDisabledToggles;
2462-
deviceTogglesDesc.disabledToggleCount = 1;
2463-
24642463
wgpu::DeviceDescriptor dev_desc;
24652464
dev_desc.requiredLimits = &ctx->limits;
24662465
dev_desc.requiredFeatures = required_features.data();
@@ -2478,7 +2477,23 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t
24782477
GGML_ABORT("ggml_webgpu: Device error! Reason: %d, Message: %s\n", static_cast<int>(reason),
24792478
std::string(message).c_str());
24802479
});
2480+
2481+
#ifndef __EMSCRIPTEN__
2482+
// Enable Dawn-specific toggles to increase native performance
2483+
// TODO: Maybe WebGPU needs a "fast" mode where you can request compilers skip adding checks like these,
2484+
// only for native performance?
2485+
const char * const deviceEnabledToggles[] = { "skip_validation", "disable_robustness", "disable_workgroup_init",
2486+
"disable_polyfills_on_integer_div_and_mod" };
2487+
const char * const deviceDisabledToggles[] = { "timestamp_quantization" };
2488+
wgpu::DawnTogglesDescriptor deviceTogglesDesc;
2489+
deviceTogglesDesc.enabledToggles = deviceEnabledToggles;
2490+
deviceTogglesDesc.enabledToggleCount = 4;
2491+
deviceTogglesDesc.disabledToggles = deviceDisabledToggles;
2492+
deviceTogglesDesc.disabledToggleCount = 1;
2493+
24812494
dev_desc.nextInChain = &deviceTogglesDesc;
2495+
#endif
2496+
24822497
ctx->instance.WaitAny(ctx->adapter.RequestDevice(
24832498
&dev_desc, wgpu::CallbackMode::AllowSpontaneous,
24842499
[ctx](wgpu::RequestDeviceStatus status, wgpu::Device device, wgpu::StringView message) {
@@ -2576,18 +2591,27 @@ ggml_backend_reg_t ggml_backend_webgpu_reg() {
25762591
ctx.name = GGML_WEBGPU_NAME;
25772592
ctx.device_count = 1;
25782593

2579-
const char * const instanceEnabledToggles[] = { "allow_unsafe_apis" };
2580-
2581-
wgpu::DawnTogglesDescriptor instanceTogglesDesc;
2582-
instanceTogglesDesc.enabledToggles = instanceEnabledToggles;
2583-
instanceTogglesDesc.enabledToggleCount = 1;
25842594
wgpu::InstanceDescriptor instance_descriptor{};
25852595
std::vector<wgpu::InstanceFeatureName> instance_features = { wgpu::InstanceFeatureName::TimedWaitAny };
25862596
instance_descriptor.requiredFeatures = instance_features.data();
25872597
instance_descriptor.requiredFeatureCount = instance_features.size();
2588-
instance_descriptor.nextInChain = &instanceTogglesDesc;
2598+
2599+
#ifndef __EMSCRIPTEN__
2600+
const char * const instanceEnabledToggles[] = { "allow_unsafe_apis" };
2601+
wgpu::DawnTogglesDescriptor instanceTogglesDesc;
2602+
instanceTogglesDesc.enabledToggles = instanceEnabledToggles;
2603+
instanceTogglesDesc.enabledToggleCount = 1;
2604+
instance_descriptor.nextInChain = &instanceTogglesDesc;
2605+
#endif
25892606

25902607
webgpu_ctx->instance = wgpu::CreateInstance(&instance_descriptor);
2608+
2609+
#ifdef __EMSCRIPTEN__
2610+
if (webgpu_ctx->instance == nullptr) {
2611+
GGML_LOG_ERROR("ggml_webgpu: Failed to create WebGPU instance. Make sure either -sASYNCIFY or -sJSPI is set\n");
2612+
return nullptr;
2613+
}
2614+
#endif
25912615
GGML_ASSERT(webgpu_ctx->instance != nullptr);
25922616

25932617
static ggml_backend_reg reg = {

scripts/serve-static.js

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
const http = require('http');
2+
const fs = require('fs').promises;
3+
const path = require('path');
4+
5+
// This file is used for testing wasm build from emscripten
6+
// Example build command:
7+
// emcmake cmake -B build-wasm -DGGML_WEBGPU=ON -DLLAMA_CURL=OFF
8+
// cmake --build build-wasm --target test-backend-ops -j
9+
10+
const PORT = 8080;
11+
const STATIC_DIR = path.join(__dirname, '../build-wasm/bin');
12+
console.log(`Serving static files from: ${STATIC_DIR}`);
13+
14+
const mimeTypes = {
15+
'.html': 'text/html',
16+
'.js': 'text/javascript',
17+
'.css': 'text/css',
18+
'.png': 'image/png',
19+
'.jpg': 'image/jpeg',
20+
'.gif': 'image/gif',
21+
'.svg': 'image/svg+xml',
22+
'.json': 'application/json',
23+
'.woff': 'font/woff',
24+
'.woff2': 'font/woff2',
25+
};
26+
27+
/**
 * Build a minimal HTML index page for a directory.
 *
 * @param {string} dirPath filesystem path of the directory to list
 * @param {string} reqUrl  request URL shown in the heading; a parent link is added unless it is '/'
 * @returns {Promise<string>} the HTML document
 */
async function generateDirListing(dirPath, reqUrl) {
    // The request URL and the file names are arbitrary text, so they must be escaped before
    // being interpolated into markup; otherwise a crafted name or URL could inject HTML.
    const HTML_ESCAPES = { '&': '&amp;', '<': '&lt;', '>': '&gt;', '"': '&quot;', "'": '&#39;' };
    const escapeHtml = (text) => String(text).replace(/[&<>"']/g, (ch) => HTML_ESCAPES[ch]);

    const files = await fs.readdir(dirPath);
    let html = `
<!DOCTYPE html>
<html>
<head>
<title>Directory Listing</title>
<style>
body { font-family: Arial, sans-serif; padding: 20px; }
ul { list-style: none; padding: 0; }
li { margin: 5px 0; }
a { text-decoration: none; color: #0066cc; }
a:hover { text-decoration: underline; }
</style>
</head>
<body>
<h1>Directory: ${escapeHtml(reqUrl)}</h1>
<ul>
`;

    if (reqUrl !== '/') {
        html += `<li><a href="../">../ (Parent Directory)</a></li>`;
    }

    for (const file of files) {
        const filePath = path.join(dirPath, file);
        const stats = await fs.stat(filePath);
        // Directories get a trailing slash in both the link target and the label.
        const suffix = stats.isDirectory() ? '/' : '';
        const link = escapeHtml(encodeURIComponent(file) + suffix);
        html += `<li><a href="${link}">${escapeHtml(file)}${suffix}</a></li>`;
    }

    html += `
</ul>
</body>
</html>
`;
    return html;
}
65+
66+
/**
 * Map a raw request URL onto a filesystem path below rootDir.
 *
 * @param {string} rootDir directory that acts as the web root
 * @param {string} rawUrl  request target as received (may carry a query string or fragment)
 * @returns {string|null} resolved path, or null when the URL points outside rootDir
 * @throws {URIError} when the path contains malformed percent-encoding
 */
function resolveStaticPath(rootDir, rawUrl) {
    // Only the path component names a file; drop any query string or fragment.
    const cut = rawUrl.search(/[?#]/);
    const rawPath = cut === -1 ? rawUrl : rawUrl.slice(0, cut);
    const decoded = decodeURIComponent(rawPath);

    const root = path.resolve(rootDir);
    const resolved = path.resolve(path.join(root, decoded));

    // After normalisation, ".." segments (literal or percent-encoded) show up as a relative
    // path that climbs out of the root; reject those.
    const relative = path.relative(root, resolved);
    const escapesRoot =
        relative === '..' || relative.startsWith('..' + path.sep) || path.isAbsolute(relative);
    return escapesRoot ? null : resolved;
}

// Static file server: serves files below STATIC_DIR, falls back to index.html or a generated
// listing for directories, and sets the cross-origin isolation and no-cache headers on every
// response.
const server = http.createServer(async (req, res) => {
    try {
        // Set COOP and COEP headers
        res.setHeader('Cross-Origin-Opener-Policy', 'same-origin');
        res.setHeader('Cross-Origin-Embedder-Policy', 'require-corp');
        res.setHeader('Cache-Control', 'no-store, no-cache, must-revalidate, proxy-revalidate');
        res.setHeader('Pragma', 'no-cache');
        res.setHeader('Expires', '0');

        const filePath = resolveStaticPath(STATIC_DIR, req.url);
        if (filePath === null) {
            // Request tried to leave the web root.
            res.writeHead(403, { 'Content-Type': 'text/plain' });
            res.end('403 Forbidden');
            return;
        }

        const stats = await fs.stat(filePath);

        if (stats.isDirectory()) {
            const indexPath = path.join(filePath, 'index.html');
            try {
                const indexData = await fs.readFile(indexPath);
                res.writeHead(200, { 'Content-Type': 'text/html' });
                res.end(indexData);
            } catch {
                // No index.html, generate directory listing
                const dirListing = await generateDirListing(filePath, req.url);
                res.writeHead(200, { 'Content-Type': 'text/html' });
                res.end(dirListing);
            }
        } else {
            const ext = path.extname(filePath).toLowerCase();
            const contentType = mimeTypes[ext] || 'application/octet-stream';
            const data = await fs.readFile(filePath);
            res.writeHead(200, { 'Content-Type': contentType });
            res.end(data);
        }
    } catch (err) {
        if (res.headersSent) {
            // A status line is already on the wire; just terminate the response.
            res.end();
            return;
        }
        if (err instanceof URIError) {
            // Malformed percent-encoding in the request path is a client error.
            res.writeHead(400, { 'Content-Type': 'text/plain' });
            res.end('400 Bad Request');
        } else if (err.code === 'ENOENT') {
            res.writeHead(404, { 'Content-Type': 'text/plain' });
            res.end('404 Not Found');
        } else {
            // Surface unexpected failures on the console instead of hiding them behind the 500.
            console.error(err);
            res.writeHead(500, { 'Content-Type': 'text/plain' });
            res.end('500 Internal Server Error');
        }
    }
});
107+
108+
// Start accepting connections and print the local URL once the socket is bound.
server.listen(PORT, function onListening() {
    console.log(`Server running at http://localhost:${PORT}/`);
});

tests/test-backend-ops.cpp

Lines changed: 26 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include <ggml-alloc.h>
2020
#include <ggml-backend.h>
2121
#include <ggml-cpp.h>
22+
#include <ggml-cpu.h>
2223

2324
#include <algorithm>
2425
#include <array>
@@ -40,12 +41,18 @@
4041
#include <thread>
4142
#include <vector>
4243

44+
// Number of worker threads used by the test harness.
// Emscripten builds use a single thread. Other builds use the hardware concurrency, clamped
// to at least 1 because std::thread::hardware_concurrency() may return 0 when the value is
// not computable, which would otherwise leave the thread pools empty.
#ifdef __EMSCRIPTEN__
#    define N_THREADS 1
#else
#    define N_THREADS (std::max<unsigned>(1, std::thread::hardware_concurrency()))
#endif
49+
4350
static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float max = 1.0f) {
4451
size_t nels = ggml_nelements(tensor);
4552
std::vector<float> data(nels);
4653
{
4754
// parallel initialization
48-
static const size_t n_threads = std::thread::hardware_concurrency();
55+
static const size_t n_threads = N_THREADS;
4956
// static RNG initialization (revisit if n_threads stops being constant)
5057
static std::vector<std::default_random_engine> generators = []() {
5158
std::random_device rd;
@@ -64,15 +71,19 @@ static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float m
6471
}
6572
};
6673

67-
std::vector<std::future<void>> tasks;
68-
tasks.reserve(n_threads);
69-
for (size_t i = 0; i < n_threads; i++) {
70-
size_t start = i*nels/n_threads;
71-
size_t end = (i+1)*nels/n_threads;
72-
tasks.push_back(std::async(std::launch::async, init_thread, i, start, end));
73-
}
74-
for (auto & t : tasks) {
75-
t.get();
74+
if (n_threads == 1) {
75+
init_thread(0, 0, nels);
76+
} else {
77+
std::vector<std::future<void>> tasks;
78+
tasks.reserve(n_threads);
79+
for (size_t i = 0; i < n_threads; i++) {
80+
size_t start = i*nels/n_threads;
81+
size_t end = (i+1)*nels/n_threads;
82+
tasks.push_back(std::async(std::launch::async, init_thread, i, start, end));
83+
}
84+
for (auto & t : tasks) {
85+
t.get();
86+
}
7687
}
7788
}
7889

@@ -104,7 +115,7 @@ static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float m
104115
};
105116

106117
const size_t min_blocks_per_thread = 1;
107-
const size_t n_threads = std::min<size_t>(std::thread::hardware_concurrency()/2,
118+
const size_t n_threads = std::min<size_t>(N_THREADS/2,
108119
std::max<size_t>(1, n_blocks / min_blocks_per_thread));
109120
std::vector<std::future<void>> tasks;
110121
tasks.reserve(n_threads);
@@ -7379,6 +7390,9 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
73797390
return false;
73807391
}
73817392

7393+
// TODO: find a better way to set the number of threads for the CPU backend
7394+
ggml_backend_cpu_set_n_threads(backend_cpu, N_THREADS);
7395+
73827396
size_t n_ok = 0;
73837397
size_t tests_run = 0;
73847398
std::vector<std::string> failed_tests;
@@ -7639,7 +7653,7 @@ int main(int argc, char ** argv) {
76397653
auto ggml_backend_set_n_threads_fn = (ggml_backend_set_n_threads_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads");
76407654
if (ggml_backend_set_n_threads_fn) {
76417655
// TODO: better value for n_threads
7642-
ggml_backend_set_n_threads_fn(backend, std::thread::hardware_concurrency());
7656+
ggml_backend_set_n_threads_fn(backend, N_THREADS);
76437657
}
76447658

76457659
size_t free, total; // NOLINT

0 commit comments

Comments
 (0)