Skip to content

Commit d5162d4

Browse files
Wojtek KowalukWojtek Kowaluk
authored andcommitted
Fixes to run on CPU and MPS
1 parent a4354c0 commit d5162d4

File tree

3 files changed

+4
-4
lines changed

3 files changed

+4
-4
lines changed

kandinsky2/configs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@
136136
"model_dim": 768,
137137
"use_scale_shift_norm": True,
138138
"resblock_updown": True,
139-
"use_fp16": True,
139+
"use_fp16": False,
140140
"cache_text_emb": True,
141141
"text_encoder_in_dim1": 1024,
142142
"text_encoder_in_dim2": 768,

kandinsky2/kandinsky2_1_model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def __init__(
5454
clip_mean,
5555
clip_std,
5656
)
57-
self.prior.load_state_dict(torch.load(prior_path), strict=False)
57+
self.prior.load_state_dict(torch.load(prior_path, map_location='cpu'), strict=False)
5858
if self.use_fp16:
5959
self.prior = self.prior.half()
6060
self.text_encoder = TextEncoder(**self.config["text_enc_params"])
@@ -88,7 +88,7 @@ def __init__(
8888

8989
self.config["model_config"]["cache_text_emb"] = True
9090
self.model = create_model(**self.config["model_config"])
91-
self.model.load_state_dict(torch.load(model_path))
91+
self.model.load_state_dict(torch.load(model_path, map_location='cpu'))
9292
if self.use_fp16:
9393
self.model.convert_to_fp16()
9494
self.image_encoder = self.image_encoder.half()

kandinsky2/model/gaussian_diffusion.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -822,7 +822,7 @@ def _extract_into_tensor(arr, timesteps, broadcast_shape):
822822
dimension equal to the length of timesteps.
823823
:return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
824824
"""
825-
res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
825+
res = th.from_numpy(arr).to(dtype=th.float32).to(device=timesteps.device)[timesteps]
826826
while len(res.shape) < len(broadcast_shape):
827827
res = res[..., None]
828828
return res.expand(broadcast_shape)

0 commit comments

Comments
 (0)