Skip to content

Commit 1020030

Browse files
committed
alright MLbench finalized for real
1 parent f6b0e7d commit 1020030

File tree

8 files changed

+204
-22
lines changed

8 files changed

+204
-22
lines changed

visualbench/benchmark.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -520,6 +520,7 @@ def _train_epoch(self, optimizer):
520520

521521
def _test_epoch(self):
522522
assert self._dltest is not None
523+
test_start = time.time()
523524
self.eval()
524525
batch_backup = self.batch
525526

@@ -528,6 +529,8 @@ def _test_epoch(self):
528529
self._one_step(optimizer=None)
529530

530531
self._last_test_time = time.time()
532+
self.log("test time", self._last_test_time - test_start, plot=False)
533+
531534
self._last_test_pass = self.num_passes
532535
self.batch = batch_backup
533536
self.train()

visualbench/logger.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ def min(self, metric): return np.min(self.list(metric))
2727
def nanmin(self, metric): return np.nanmin(self.list(metric))
2828
def max(self, metric): return np.max(self.list(metric))
2929
def nanmax(self, metric): return np.nanmax(self.list(metric))
30+
def sum(self, metric): return np.sum(self.list(metric))
3031

3132
def interp(self, metric: str) -> np.ndarray:
3233
"""Returns a list of values for a given key, interpolating missing steps."""

visualbench/models/ode.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ def forward(self, t, z: torch.Tensor):
1616

1717
# test 'dopri5', 'adams'
1818
class NeuralODE(nn.Module):
19-
def __init__(self, in_channels: int, out_channels: int, width: int, act_cls = F.softplus, layer_norm=False, T = 10., steps = 2, adjoint = False, method = 'implicit_adams'):
19+
def __init__(self, in_channels: int, out_channels: int, width: int, act_cls = torch.nn.Softplus, layer_norm=False, T = 10., steps = 2, adjoint = False, method = 'implicit_adams'):
2020
super().__init__()
2121
self.in_layer = nn.Linear(in_channels, width)
2222
self.ode_func = _ODELinear(width, act_cls = act_cls, layer_norm=layer_norm)

visualbench/runs/benchmark_benchmark.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ def quickrun(self):
153153
opt = lambda p, lr: torch.optim.RMSprop(p, lr)
154154
self.run_optimizer(opt, "RMSprop", tune=True, max_dim=None)
155155

156-
opt = lambda p, lr: tz.Optimizer(p, tz.m.SOAP(), tz.m.LR(lr))
156+
opt = lambda p, lr: tz.Optimizer(p, tz.m.SOAP(max_dim=2048), tz.m.LR(lr))
157157
self.run_optimizer(opt, "SOAP", tune=True, max_dim=None)
158158

159159

@@ -198,7 +198,7 @@ def run_stochastic(self):
198198
opt = lambda p, lr: tz.Optimizer(p, tz.m.GGT(), tz.m.LR(lr))
199199
self.run_optimizer(opt, "GGT", tune=True, max_dim=None)
200200

201-
opt = lambda p, lr: tz.Optimizer(p, tz.m.SOAP(), tz.m.LR(lr))
201+
opt = lambda p, lr: tz.Optimizer(p, tz.m.SOAP(max_dim=2048), tz.m.LR(lr))
202202
self.run_optimizer(opt, "SOAP", tune=True, max_dim=None)
203203

204204
# PSGD Kron

visualbench/runs/benchpack.py

Lines changed: 63 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -82,54 +82,110 @@ def run_bench(bench: "Benchmark", task_name: str, passes: int, sec: float, metri
8282
if max_dim is not None and dim > max_dim: return
8383

8484
start = time.time()
85+
test_time = 0
8586
clean_mem()
8687

8788
# skip CPU because accelerator state can't change.
8889
if (accelerate) and (Accelerator is not None) and (next(bench.parameters()).is_cuda):
8990
accelerator = Accelerator()
9091
bench = accelerator.prepare(bench)
9192

93+
# -------------------------------- logger func ------------------------------- #
9294
def logger_fn(value: float):
9395
if dim > 100_000: clean_mem()
9496

97+
# set seed
9598
torch.manual_seed(0)
9699
np.random.seed(0)
97100
random.seed(0)
98101

102+
# run
99103
bench.reset().set_performance_mode().set_print_inverval(None)
100104
opt = init_fn(opt_fn, bench, value)
101105
bench.run(opt, max_passes=passes, max_seconds=sec, test_every_forwards=test_every, num_extra_passes=num_extra_passes, step_callbacks=step_callbacks)
106+
107+
# print progress
102108
if print_progress and bench.seconds_passed is not None and bench.seconds_passed > sec:
103109
print(f"{sweep_name}: '{task_name}' timeout, {bench.seconds_passed} > {sec}!")
110+
111+
# add test time
112+
if "test time" in bench.logger:
113+
nonlocal test_time
114+
test_time += bench.logger.sum("test time")
115+
104116
return bench.logger
105117

118+
# --------------------------------- single run ------------------------------- #
106119
if (hyperparam is None) or (not tune):
107-
sweep = single_run(logger_fn, metrics=metrics, fixed_hyperparams=fixed_hyperparams, root=root, task_name=task_name, run_name=sweep_name, print_records=print_records, print_progress=print_progress, save=save, load_existing=load_existing)
108-
120+
sweep = single_run(
121+
logger_fn,
122+
metrics=metrics,
123+
fixed_hyperparams=fixed_hyperparams,
124+
root=root,
125+
task_name=task_name,
126+
run_name=sweep_name,
127+
print_records=print_records,
128+
print_progress=print_progress,
129+
save=save,
130+
load_existing=load_existing,
131+
)
132+
133+
# -------------------------------- mbs search -------------------------------- #
109134
else:
110-
sweep = mbs_search(logger_fn, metrics=metrics, search_hyperparam=hyperparam, fixed_hyperparams=fixed_hyperparams, log_scale=log_scale, grid=grid, step=step, num_candidates=num_candidates, num_binary=max(1, int(num_binary*binary_mul)), num_expansions=num_expansions, rounding=rounding, root=root, task_name=task_name, run_name=sweep_name, print_records=print_records, save=save, load_existing=load_existing, print_progress=print_progress)
111-
112-
# render video
135+
sweep = mbs_search(
136+
logger_fn,
137+
metrics=metrics,
138+
search_hyperparam=hyperparam,
139+
fixed_hyperparams=fixed_hyperparams,
140+
log_scale=log_scale,
141+
grid=grid,
142+
step=step,
143+
num_candidates=num_candidates,
144+
num_binary=max(1, int(num_binary * binary_mul)),
145+
num_expansions=num_expansions,
146+
rounding=rounding,
147+
root=root,
148+
task_name=task_name,
149+
run_name=sweep_name,
150+
print_records=print_records,
151+
save=save,
152+
load_existing=load_existing,
153+
print_progress=print_progress,
154+
)
155+
156+
# ------------------------------- render video ------------------------------- #
113157
if (render_vids) and (vid_scale is not None) and (self.summaries_root is not None):
114158
assert self.summary_dir is not None
115159
for metric, maximize in _target_metrics_to_dict(metrics).items():
160+
161+
# check if video already exists and skip if it does
116162
video_path = os.path.join(self.summary_dir, f'{task_name} - {metric}')
117163
if os.path.exists(f'{video_path}.mp4'): continue
118164

165+
# find hyperparameter value of the best run
119166
best_run = sweep.best_runs(metric, maximize, 1)[0]
120167
value = 0
121168
if tune and hyperparam is not None: value = best_run.hyperparams[hyperparam]
169+
170+
# run benchmark with visualization enabled
122171
bench.reset().set_performance_mode(False).set_print_inverval(None)
123172
opt = init_fn(opt_fn, bench, value)
124173
bench.run(opt, max_passes=passes, max_seconds=sec, test_every_forwards=test_every, num_extra_passes=num_extra_passes)
174+
175+
# make dirs and render to __TEMP__.mp4 to avoid saving partial renders
125176
if not os.path.exists(self.summaries_root): os.mkdir(self.summaries_root)
126177
if not os.path.exists(self.summary_dir): os.mkdir(self.summary_dir)
127178
bench.render(f'{video_path} __TEMP__', scale=vid_scale, fps=fps, progress=False)
179+
180+
# after successful render renamed __TEMP__.mp4 to actual path
128181
os.rename(f'{video_path} __TEMP__.mp4', f'{video_path}.mp4')
129182

183+
# -------------------------------- print time -------------------------------- #
130184
if print_time:
131-
if print_progress: print(" ", end="\r")
132-
print(f"{task_name} took {(time.time() - start):.2f} s.")
185+
if print_progress: print(" " * 1000, end="\r")
186+
s = f"{task_name} took {(time.time() - start):.2f} s."
187+
if test_time != 0: s = f"{s}; test epochs took {float(test_time):.2f} s."
188+
print(s)
133189

134190
self.run_bench = run_bench
135191

visualbench/runs/mlbench.py

Lines changed: 27 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from collections.abc import Callable, Iterable, Mapping, Sequence
44
from typing import TYPE_CHECKING, Any
55

6+
import rtdl_revisiting_models
67
import torch
78
from monai.losses.dice import DiceFocalLoss
89
from torch import nn
@@ -70,7 +71,7 @@ def run_ml(self):
7071
# # ndim = 132,611
7172
# # 22s. ~ 7m. 20s.
7273
# # 9+3=12 ~ 4m. 20s.
73-
# bench = tasks.WavePINN(tasks.WavePINN.FLS(2, 1, hidden_size=256, n_hidden=3)).to(CUDA_IF_AVAILABLE)
74+
# bench = tasks.WavePINN(tasks.WavePINN.FLS(2, 1, hidden_size=256, n_hidden=3), n_pde=512, n_ic=256, n_bc=256).to(CUDA_IF_AVAILABLE)
7475
# self.run_bench(bench, 'ML - Wave PDE - FLS', passes=10_000, sec=600, metrics='train loss', vid_scale=4)
7576

7677
def run_mls(self):
@@ -81,7 +82,7 @@ def run_mls(self):
8182
# 5s. ~ 1m. 40s.
8283
bench = tasks.Collinear(models.MLP([32, 10]), batch_size=1).to(CUDA_IF_AVAILABLE)
8384
bench_name = 'MLS - Ill-conditioned logistic regression BS-1'
84-
self.run_bench(bench, bench_name, passes=10_000, sec=600, test_every=50, metrics='test loss', vid_scale=None)
85+
self.run_bench(bench, bench_name, passes=20_000, sec=1_000, test_every=50, metrics='test loss', vid_scale=None)
8586

8687
# --------------------------- Matrix factorization --------------------------- #
8788
# ...
@@ -90,32 +91,46 @@ def run_mls(self):
9091
path = "MovieLens-100k/ml-100k"
9192
if not os.path.exists(path):
9293
path = load_movie_lens()
93-
bench = tasks.MFMovieLens(path, batch_size=32, device='cuda').cuda()
94+
bench = tasks.MFMovieLens(path, batch_size=32, device='cuda').to(CUDA_IF_AVAILABLE)
9495
bench_name = 'MLS - MovieLens BS-32 - Matrix Factorization'
95-
self.run_bench(bench, bench_name, passes=10_000, sec=600, test_every=50, metrics='test loss', vid_scale=None)
96+
self.run_bench(bench, bench_name, passes=20_000, sec=1_000, test_every=50, metrics='test loss', vid_scale=None)
9697

9798
# ------------------------------ MLP (Colinear) ------------------------------ #
9899
model = models.MLP([32, 64, 96, 128, 256, 10])
99-
bench = tasks.Collinear(model, batch_size=64, test_batch_size=4096).cuda()
100+
bench = tasks.Collinear(model, batch_size=64, test_batch_size=4096).to(CUDA_IF_AVAILABLE)
100101
bench_name = 'MLS - Colinear BS-64 - MLP(32-64-96-128-256-10)'
101-
self.run_bench(bench, bench_name, passes=10_000, sec=600, test_every=100, metrics='test loss', vid_scale=None)
102+
self.run_bench(bench, bench_name, passes=20_000, sec=1_000, test_every=100, metrics='test loss', vid_scale=None)
102103

103104
# ------------------------------- RNN (MNIST-1D) ------------------------------ #
104105
# ndim = 20,410
105106
# 11s. ~ 3m. 30s.
106-
bench = tasks.datasets.Mnist1d(
107+
bench = tasks.Mnist1d(
107108
models.RNN(1, 10, hidden_size=40, num_layers=2, rnn=torch.nn.RNN),
108109
batch_size=128,
109110
).to(CUDA_IF_AVAILABLE)
110-
bench_name = 'MLS - MNIST-1D BS-128 - RNN(2x40)'
111-
self.run_bench(bench, bench_name, passes=10_000, sec=600, test_every=20, metrics='test loss', vid_scale=None, binary_mul=0.5)
111+
bench_name = 'MLS - Mnist1d-5_000 BS-128 - RNN(2x40)'
112+
self.run_bench(bench, bench_name, passes=20_000, sec=1_000, test_every=20, metrics='test loss', vid_scale=None)
113+
114+
# ------------------------- FTTransformer (MNIST-1D) ------------------------- #
115+
class NoCat(torch.nn.Module):
116+
def __init__(self):
117+
super().__init__()
118+
self.model = rtdl_revisiting_models.FTTransformer(n_cont_features=40, cat_cardinalities=[], d_out=10,
119+
**rtdl_revisiting_models.FTTransformer.get_default_kwargs(1))
120+
121+
def forward(self, x):
122+
return self.model.forward(x, None)
123+
124+
bench = tasks.Mnist1d(NoCat(), batch_size=32, test_batch_size=1024, num_samples=20_000).to(CUDA_IF_AVAILABLE)
125+
bench_name = 'MLS - Mnist1d-20_000 BS-32 - FTTransformer'
126+
self.run_bench(bench, bench_name, passes=20_000, sec=1_000, test_every=200, metrics='test loss', vid_scale=None)
112127

113128
# ---------------------------- ConvNet (MNIST-1D) ---------------------------- #
114129
# ndim = 134,410
115-
bench = tasks.datasets.Mnist1d(
130+
bench = tasks.Mnist1d(
116131
models.vision.ConvNet(40, 1, 10, widths=(64, 128, 256), dropout=0.7),
117132
batch_size=32, test_batch_size=256
118133
).to(CUDA_IF_AVAILABLE)
119-
bench_name = "MLS - MNIST-1D BS-32 - ConvNet"
120-
self.run_bench(bench, bench_name, passes=20_000, sec=1000, test_every=50, metrics = "test loss", vid_scale=None)
134+
bench_name = "MLS - Mnist1d-5_000 BS-32 - ConvNet"
135+
self.run_bench(bench, bench_name, passes=20_000, sec=1_000, test_every=50, metrics = "test loss", vid_scale=None)
121136

visualbench/tasks/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
test_functions,
2525
TEST_FUNCTIONS,
2626
)
27+
from .tammes import Tammes
2728
from .glimmer import Glimmer
2829
from .gmm import GaussianMixtureNLL
2930
from .graph_layout import GraphLayout

visualbench/tasks/tammes.py

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
import math
2+
3+
import cv2
4+
import numpy as np
5+
import torch
6+
from torch import nn
7+
from ..benchmark import Benchmark
8+
9+
class Tammes(Benchmark):
10+
"""Tammes problem is to maximize minimal distance between points on a sphere
11+
12+
Points are parameterized by spherical coordinates (theta, phi).
13+
14+
Renders:
15+
points.
16+
17+
Args:
18+
num_points (int): The number of points (N) on the sphere.
19+
initial_dist_epsilon (float): Small value to perturb initial positions
20+
to avoid stacking points and poles.
21+
"""
22+
def __init__(self, num_points: int, initial_dist_epsilon: float = 1e-3, resolution=256, p=2, draw_lines=None):
23+
super().__init__()
24+
if num_points < 2:
25+
raise ValueError("Number of points must be at least 2.")
26+
self.num_points = num_points
27+
28+
self.p=p
29+
initial_thetas = torch.rand(num_points) * (math.pi - 2 * initial_dist_epsilon) + initial_dist_epsilon
30+
initial_phis = torch.rand(num_points) * (2 * math.pi)
31+
32+
self.thetas = nn.Parameter(initial_thetas)
33+
self.phis = nn.Parameter(initial_phis)
34+
35+
self.eps = 1e-12
36+
37+
if draw_lines is None: draw_lines = num_points < 12
38+
self.draw_lines = draw_lines
39+
self.resolution = resolution
40+
41+
def spherical_to_cartesian(self, thetas: torch.Tensor, phis: torch.Tensor) -> torch.Tensor:
42+
"""Converts spherical coordinates (unit radius) to Cartesian coordinates."""
43+
x = torch.sin(thetas) * torch.cos(phis)
44+
y = torch.sin(thetas) * torch.sin(phis)
45+
z = torch.cos(thetas)
46+
# (num_points, 3)
47+
coords = torch.stack([x, y, z], dim=1)
48+
return coords
49+
50+
@torch.no_grad
51+
def _make_frame(
52+
self,
53+
coords: torch.Tensor,
54+
img_size: int = 512,
55+
point_radius: int = 5,
56+
draw_lines: bool = False,
57+
line_thickness: int = 1,
58+
line_color: tuple[int, int, int] = (70, 70, 70) # Faint grey BGR
59+
) -> np.ndarray:
60+
frame = np.zeros((img_size, img_size, 3), dtype=np.uint8)
61+
cv2.circle(frame, (img_size // 2, img_size // 2), img_size // 2 - 1, (50, 50, 50), 1, cv2.LINE_AA) # pylint:disable=no-member
62+
63+
coords_np = coords.detach().cpu().numpy()
64+
65+
# Project onto xy plane and scale to image coordinates
66+
# x, y are in [-1, 1], map to [0, img_size]
67+
img_coords = []
68+
for i in range(self.num_points):
69+
x, y, z = coords_np[i]
70+
# Scale x, y from [-1, 1] to [0, img_size]
71+
img_x = int((x + 1.0) / 2.0 * img_size)
72+
img_y = int((y + 1.0) / 2.0 * img_size)
73+
img_coords.append(((img_x, img_y), z))
74+
75+
# Draw lines firstso points are drawn on top
76+
if draw_lines:
77+
for i in range(self.num_points):
78+
pt1, _ = img_coords[i]
79+
for j in range(i + 1, self.num_points):
80+
pt2, _ = img_coords[j]
81+
cv2.line(frame, pt1, pt2, line_color, line_thickness, cv2.LINE_AA) # pylint:disable=no-member
82+
83+
# Points (circles)
84+
for i in range(self.num_points):
85+
(img_x, img_y), z = img_coords[i]
86+
87+
# Color/brightness to indicate depth (z coordinate)
88+
intensity = int((z + 1.0) / 2.0 * 200) + 55 # Map z=[-1,1] to brightness [55, 255]
89+
color = (intensity // 2, intensity // 2, intensity) # BGR, bias towards blue/white
90+
91+
cv2.circle(frame, (img_x, img_y), point_radius, color, -1, cv2.LINE_AA) # filled circle # pylint:disable=no-member
92+
93+
return frame
94+
95+
def get_loss(self) -> torch.Tensor:
96+
cartesian = self.spherical_to_cartesian(self.thetas, self.phis)
97+
pdists = torch.cdist(cartesian, cartesian, p=self.p)
98+
pdists = pdists + torch.eye(pdists.size(0), device=pdists.device, dtype=pdists.dtype) * pdists.amax().detach() * 2
99+
100+
loss = 1 / pdists.amin()
101+
102+
if self._make_images:
103+
frame = self._make_frame(cartesian, img_size=self.resolution, draw_lines=self.draw_lines)
104+
self.log_image('solution', frame, to_uint8=False, show_best=True)
105+
106+
return loss

0 commit comments

Comments
 (0)