Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
43 commits
Select commit Hold shift + click to select a range
db76617
div
kschmid23 Oct 21, 2025
dbf82d4
Merge branch 'main' into debug3
kschmid23 Oct 21, 2025
44d0e76
div
kschmid23 Oct 21, 2025
71bc0fe
div
kschmid23 Oct 21, 2025
bce729a
div
kschmid23 Oct 21, 2025
ed1231c
div
kschmid23 Oct 21, 2025
24bf941
div
kschmid23 Oct 21, 2025
867674d
div
kschmid23 Oct 21, 2025
c8ca524
div
kschmid23 Oct 21, 2025
b3802cc
deb
kschmid23 Oct 21, 2025
89b99cd
deb
kschmid23 Oct 21, 2025
b4f9833
deb
kschmid23 Oct 21, 2025
966420c
deb
kschmid23 Oct 21, 2025
283b96a
deb
kschmid23 Oct 21, 2025
59ab115
deb
kschmid23 Oct 21, 2025
4e6dbb3
deb
kschmid23 Oct 21, 2025
92de546
deb
kschmid23 Oct 21, 2025
094ecc9
deb
kschmid23 Oct 21, 2025
bb51330
deb
kschmid23 Oct 21, 2025
31bd8ef
deb
kschmid23 Oct 21, 2025
c4993cd
deb
kschmid23 Oct 21, 2025
8c8996f
deb
kschmid23 Oct 21, 2025
967a92c
deb
kschmid23 Oct 21, 2025
e407dcc
deb
kschmid23 Oct 21, 2025
5cf35ae
deb
kschmid23 Oct 21, 2025
ea5d0f2
deb
kschmid23 Oct 21, 2025
f20e76e
deb
kschmid23 Oct 21, 2025
630877c
deb
kschmid23 Oct 21, 2025
10b3d9d
deb
kschmid23 Oct 21, 2025
0cbd0b5
Merge branch 'debug3' of https://github.com/3a1b2c3/ai-toolkit into d…
kschmid23 Oct 21, 2025
bb1eba5
debug
kschmid23 Oct 21, 2025
7a6ccfb
debug
kschmid23 Oct 21, 2025
fb6aee2
debug
kschmid23 Oct 21, 2025
a940602
debug
kschmid23 Oct 21, 2025
f90e651
asssert
kschmid23 Oct 21, 2025
358ebdb
asssert
kschmid23 Oct 21, 2025
54892f4
asssert
kschmid23 Oct 21, 2025
68664dd
asssert
kschmid23 Oct 21, 2025
6a47fe4
Merge branch 'main' into debug3
kschmid23 Oct 22, 2025
c646e14
batch
kschmid23 Oct 22, 2025
6fcc9b8
Merge branch 'main' into debug3
kschmid23 Oct 22, 2025
b19c5d5
netw
kschmid23 Oct 23, 2025
b85d863
Merge branch 'main' into debug3
kschmid23 Oct 27, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
6 changes: 3 additions & 3 deletions extensions_built_in/diffusion_models/f_light/src/model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# originally from https://github.com/fal-ai/f-lite/blob/main/f_lite/model.py but modified slightly

import math
import numpy as np

import torch
import torch.nn.functional as F
Expand All @@ -15,9 +16,8 @@

def timestep_embedding(t, dim, max_period=10000):
half = dim // 2
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
device=t.device
)
freqs = np.exp(-math.log(max_period) * np.arange(start=0, end=half, dtype=np.float32) / half)
freqs = torch.from_numpy(freqs).to(device=t.device)
args = t[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)

Expand Down
17 changes: 12 additions & 5 deletions extensions_built_in/diffusion_models/wan22/wan22_14b_i2v_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,13 @@ def generate_single_image(
# reactivate progress bar since this is slooooow
pipeline.set_progress_bar_config(disable=False)

num_frames = (
(gen_config.num_frames - 1) // 4
) * 4 + 1 # make sure it is divisible by 4 + 1
# Prevent division by zero for num_frames
if gen_config.num_frames is None or gen_config.num_frames < 1:
num_frames = 1
else:
num_frames = (
(gen_config.num_frames - 1) // 4
) * 4 + 1 # make sure it is divisible by 4 + 1
gen_config.num_frames = num_frames

height = gen_config.height
Expand All @@ -42,10 +46,13 @@ def generate_single_image(
control_img = Image.open(gen_config.ctrl_img).convert("RGB")

d = self.get_bucket_divisibility()
# Prevent division by zero for d
if d is None or d == 0:
d = 1

# make sure they are divisible by d
height = height // d * d
width = width // d * d
height = height // d * d if height is not None else d
width = width // d * d if width is not None else d

# resize the control image
control_img = control_img.resize((width, height), Image.LANCZOS)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -447,7 +447,11 @@ def save_lora(
output_path: str,
metadata: Optional[Dict[str, Any]] = None,
):
if not self.network.network_config.split_multistage_loras:
# Use .module if network is DDP
network_config = getattr(self.network, "network_config", None)
if hasattr(self.network, "module"):
network_config = getattr(self.network.module, "network_config", None)
if not network_config or not getattr(network_config, "split_multistage_loras", False):
# just save as a combo lora
save_file(state_dict, output_path, metadata=metadata)
return
Expand Down
2 changes: 2 additions & 0 deletions extensions_built_in/sd_trainer/SDTrainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1180,6 +1180,7 @@ def predict_noise(
guidance_embedding_scale = self.train_config.cfg_scale
if self.train_config.do_guidance_loss:
guidance_embedding_scale = self._guidance_loss_target_batch
assert batch.tensor is not None, "Batch tensor is None in predict_noise"
return self.sd.predict_noise(
latents=noisy_latents.to(self.device_torch, dtype=dtype),
conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype),
Expand Down Expand Up @@ -2034,6 +2035,7 @@ def hook_train_loop(self, batch: Union[DataLoaderBatchDTO, List[DataLoaderBatchD
batch_list = [batch]
total_loss = None
self.optimizer.zero_grad()
assert len(batch_list) > 0, "No batches to process in train loop"
for batch in batch_list:
if self.sd.is_multistage:
# handle multistage switching
Expand Down
65 changes: 43 additions & 22 deletions jobs/process/BaseSDTrainProcess.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ def __init__(self, process_id: int, job, config: OrderedDict, custom_pipeline=No

# fine_tuning here is for training actual SD network, not LoRA, embeddings, etc. it is (Dreambooth, etc)
self.is_fine_tuning = True
if self.network_config is not None or is_training_adapter or self.embed_config is not None or self.decorator_config is not None:
if self.network_config_safe is not None or is_training_adapter or self.embed_config is not None or self.decorator_config is not None:
self.is_fine_tuning = False

self.named_lora = False
Expand Down Expand Up @@ -287,6 +287,8 @@ def sample(self, step=None, is_first=False):
for i in range(len(sample_config.prompts)):
test_image_paths.append(test_image_path_list[i % len(test_image_path_list)])

print_acc("sample", len(sample_config.prompts))

for i in range(len(sample_config.prompts)):
if sample_config.walk_seed:
current_seed = start_seed + i
Expand Down Expand Up @@ -352,28 +354,35 @@ def sample(self, step=None, is_first=False):
do_cfg_norm=sample_config.do_cfg_norm,
**extra_args
))

# post process
gen_img_config_list = self.post_process_generate_image_config_list(gen_img_config_list)

# if we have an ema, set it to validation mode
if self.ema is not None:
self.ema.eval()

# let adapter know we are sampling
if self.adapter is not None and isinstance(self.adapter, CustomAdapter):
self.adapter.is_sampling = True

# send to be generated
self.sd.generate_images(gen_img_config_list, sampler=sample_config.sampler)

print_acc(len(gen_img_config_list), self.sd.network is not None, "gen_img_config_list2", self.ema,
self.model_config.assistant_lora_path is not None, self.model_config.inference_lora_path is not None)
if gen_img_config_list:
print("[DEBUG] generate_images: after lora handling",gen_img_config_list)
# send to be generated
self.sd.generate_images(gen_img_config_list, sampler=sample_config.sampler)
print("[DEBUG] generate_images: after lora handling1")

print("gen_img_config_list3", len(gen_img_config_list))
if self.adapter is not None and isinstance(self.adapter, CustomAdapter):
self.adapter.is_sampling = False

if self.ema is not None:
self.ema.train()

print("gen_img_config_list4", len(gen_img_config_list))

def update_training_metadata(self):
o_dict = OrderedDict({
"training_info": self.get_training_info()
Expand Down Expand Up @@ -447,10 +456,12 @@ def clean_up_saves(self):
combined_items.sort(key=os.path.getctime)

num_saves_to_keep = self.save_config.max_step_saves_to_keep

if hasattr(self.sd, 'max_step_saves_to_keep_multiplier'):
num_saves_to_keep *= self.sd.max_step_saves_to_keep_multiplier

assert num_saves_to_keep > 0

# Use slicing with a check to avoid 'NoneType' error
safetensors_to_remove = safetensors_files[
:-num_saves_to_keep] if safetensors_files else []
Expand Down Expand Up @@ -1036,7 +1047,7 @@ def process_general_training_batch(self, batch: 'DataLoaderBatchDTO'):
# do random prompt saturation by expanding the prompt to hit at least 77 tokens
if random.random() < self.train_config.prompt_saturation_chance:
est_num_tokens = len(prompt.split(' '))
if est_num_tokens < 77:
if est_num_tokens < 77 and est_num_tokens > 0:
num_repeats = int(77 / est_num_tokens) + 1
prompt = ', '.join([prompt] * num_repeats)

Expand Down Expand Up @@ -1070,6 +1081,7 @@ def process_general_training_batch(self, batch: 'DataLoaderBatchDTO'):
# Standard Deviation: tensor([0.5623, 0.5295, 0.5347])
imgs_channel_mean = imgs.mean(dim=(2, 3), keepdim=True)
imgs_channel_std = imgs.std(dim=(2, 3), keepdim=True)
assert imgs_channel_std.min() > 0, "Image channel std is too small, cannot standardize"
imgs = (imgs - imgs_channel_mean) / imgs_channel_std
target_mean = torch.tensor(target_mean_list, device=self.device_torch, dtype=dtype)
target_std = torch.tensor(target_std_list, device=self.device_torch, dtype=dtype)
Expand All @@ -1095,6 +1107,7 @@ def process_general_training_batch(self, batch: 'DataLoaderBatchDTO'):

latents_channel_mean = latents.mean(dim=(2, 3), keepdim=True)
latents_channel_std = latents.std(dim=(2, 3), keepdim=True)
assert latents_channel_std.min() > 0, "Latent channel std is too small, cannot standardize"
latents = (latents - latents_channel_mean) / latents_channel_std
target_mean = torch.tensor(target_mean_list, device=self.device_torch, dtype=dtype)
target_std = torch.tensor(target_std_list, device=self.device_torch, dtype=dtype)
Expand Down Expand Up @@ -1503,6 +1516,13 @@ def setup_adapter(self):
# set trainable params
self.sd.adapter = self.adapter

@property
def network_config_safe(self):
# Use .module if wrapped by DDP, else direct
if hasattr(self, "module") and hasattr(self.module, "network_config"):
return self.module.network_config
return getattr(self, "network_config", None)

def run(self):
# torch.autograd.set_detect_anomaly(True)
# run base process run
Expand Down Expand Up @@ -1690,14 +1710,14 @@ def run(self):
self.hook_after_model_load()
flush()
if not self.is_fine_tuning:
if self.network_config is not None:
if self.network_config_safe is not None:
# TODO should we completely switch to LycorisSpecialNetwork?
network_kwargs = self.network_config.network_kwargs
network_kwargs = self.network_config_safe.network_kwargs
is_lycoris = False
is_lorm = self.network_config.type.lower() == 'lorm'
is_lorm = self.network_config_safe.type.lower() == 'lorm'
# default to LoCON if there are any conv layers or if it is named
NetworkClass = LoRASpecialNetwork
if self.network_config.type.lower() == 'locon' or self.network_config.type.lower() == 'lycoris':
if self.network_config_safe.type.lower() == 'locon' or self.network_config_safe.type.lower() == 'lycoris':
NetworkClass = LycorisSpecialNetwork
is_lycoris = True

Expand All @@ -1716,13 +1736,13 @@ def run(self):
self.network = NetworkClass(
text_encoder=text_encoder,
unet=self.sd.get_model_to_train(),
lora_dim=self.network_config.linear,
lora_dim=self.network_config_safe.linear,
multiplier=1.0,
alpha=self.network_config.linear_alpha,
alpha=self.network_config_safe.linear_alpha,
train_unet=self.train_config.train_unet,
train_text_encoder=self.train_config.train_text_encoder,
conv_lora_dim=self.network_config.conv,
conv_alpha=self.network_config.conv_alpha,
conv_lora_dim=self.network_config_safe.conv,
conv_alpha=self.network_config_safe.conv_alpha,
is_sdxl=self.model_config.is_xl or self.model_config.is_ssd,
is_v2=self.model_config.is_v2,
is_v3=self.model_config.is_v3,
Expand All @@ -1732,14 +1752,15 @@ def run(self):
is_lumina2=self.model_config.is_lumina2,
is_ssd=self.model_config.is_ssd,
is_vega=self.model_config.is_vega,
dropout=self.network_config.dropout,
dropout=self.network_config_safe.dropout,
use_text_encoder_1=self.model_config.use_text_encoder_1,
use_text_encoder_2=self.model_config.use_text_encoder_2,
use_bias=is_lorm,

is_lorm=is_lorm,
network_config=self.network_config,
network_type=self.network_config.type,
transformer_only=self.network_config.transformer_only,
network_config=self.network_config_safe,
network_type=self.network_config_safe.type,
transformer_only=self.network_config_safe.transformer_only,
is_transformer=self.sd.is_transformer,
base_model=self.sd,
**network_kwargs
Expand Down Expand Up @@ -2012,7 +2033,7 @@ def run(self):
elif self.step_num <= 1 or self.train_config.force_first_sample:
print_acc("Generating baseline samples before training")
self.sample(self.step_num)

print("Starting training loop...", self.accelerator.is_local_main_process)
if self.accelerator.is_local_main_process:
self.progress_bar = ToolkitProgressBar(
total=self.train_config.steps,
Expand Down Expand Up @@ -2058,7 +2079,7 @@ def run(self):
###################################################################
# TRAIN LOOP
###################################################################

print("Starting training loop... self.train_config.steps", self.train_config.steps)

start_step_num = self.step_num
did_first_flush = False
Expand Down
2 changes: 1 addition & 1 deletion run.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def main():

if accelerator.is_main_process:
print_acc(f"Running {len(config_file_list)} job{'' if len(config_file_list) == 1 else 's'}")

job = None
for config_file in config_file_list:
try:
job = get_job(config_file, args.name)
Expand Down
3 changes: 2 additions & 1 deletion toolkit/config_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import random

import torch
import torchaudio


from toolkit.prompt_utils import PromptEmbeds

Expand Down Expand Up @@ -1150,6 +1150,7 @@ def save_image(self, image, count: int = 0, max_count=0):
else:
raise ValueError(f"Unsupported video format {self.output_ext}")
elif self.output_ext in ['wav', 'mp3']:
import torchaudio
# save audio file
torchaudio.save(
self.get_image_path(count, max_count),
Expand Down
10 changes: 5 additions & 5 deletions toolkit/models/i2v_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,8 +335,11 @@ def __init__(
model_class = sd.model.__class__.__name__

if self.network_config is not None:

network_kwargs = {} if self.network_config.network_kwargs is None else self.network_config.network_kwargs
# Support DDP-wrapped network_config
network_config = self.network_config
if hasattr(network_config, "module"):
network_config = network_config.module
network_kwargs = {} if getattr(network_config, "network_kwargs", None) is None else network_config.network_kwargs
if hasattr(sd, 'target_lora_modules'):
network_kwargs['target_lin_modules'] = sd.target_lora_modules

Expand All @@ -363,9 +366,6 @@ def __init__(
train_text_encoder=self.train_config.train_text_encoder,
conv_lora_dim=self.network_config.conv,
conv_alpha=self.network_config.conv_alpha,
is_sdxl=self.model_config.is_xl or self.model_config.is_ssd,
is_v2=self.model_config.is_v2,
is_v3=self.model_config.is_v3,
is_pixart=self.model_config.is_pixart,
is_auraflow=self.model_config.is_auraflow,
is_flux=self.model_config.is_flux,
Expand Down
16 changes: 10 additions & 6 deletions toolkit/samplers/custom_flowmatch_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,19 @@ def __init__(self, *args, **kwargs):
num_timesteps = 1000
# Bell-Shaped Mean-Normalized Timestep Weighting
# bsmntw? need a better name

x = torch.arange(num_timesteps, dtype=torch.float32)
y = torch.exp(-2 * ((x - num_timesteps / 2) / num_timesteps) ** 2)

print(13)
# Guard against division by zero
try:
denom = num_timesteps
y = torch.from_numpy(
np.exp(-2 * ((x.numpy() - num_timesteps / 2) / denom) ** 2)
)
except ZeroDivisionError:
y = torch.ones_like(x)
print(15)
# Shift minimum to 0
y_shifted = y - y.min()

# Scale to make mean 1
bsmntw_weighing = y_shifted * (num_timesteps / y_shifted.sum())

Expand Down Expand Up @@ -180,8 +186,6 @@ def set_train_timesteps(
if self.config.invert_sigmas:
sigmas = 1.0 - sigmas
timesteps = sigmas * self.config.num_train_timesteps
sigmas = torch.cat(
[sigmas, torch.ones(1, device=sigmas.device)])
else:
sigmas = torch.cat(
[sigmas, torch.zeros(1, device=sigmas.device)])
Expand Down
Loading