Skip to content

Commit d599fd7

Browse files
authored
adapted pretrained model to training (#371)
1 parent cb90d06 commit d599fd7

File tree

2 files changed

+18
-3
lines changed

2 files changed

+18
-3
lines changed

opensora/models/stdit/stdit2.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import numpy as np
22
import torch
33
import torch.nn as nn
4+
import os
45
from einops import rearrange
56
from rotary_embedding_torch import RotaryEmbedding
67
from timm.models.layers import DropPath
@@ -23,6 +24,7 @@
2324
)
2425
from opensora.registry import MODELS
2526
from transformers import PretrainedConfig, PreTrainedModel
27+
from opensora.utils.ckpt_utils import load_checkpoint
2628

2729

2830
class STDiT2Block(nn.Module):
@@ -502,8 +504,22 @@ def _basic_init(module):
502504
@MODELS.register_module("STDiT2-XL/2")
503505
def STDiT2_XL_2(from_pretrained=None, **kwargs):
504506
if from_pretrained is not None:
505-
model = STDiT2.from_pretrained(from_pretrained, **kwargs)
507+
if os.path.isdir(from_pretrained) or os.path.isfile(from_pretrained):
508+
# if it is a directory or a file, we load the checkpoint manually
509+
config = STDiT2Config(
510+
depth=28,
511+
hidden_size=1152,
512+
patch_size=(1, 2, 2),
513+
num_heads=16, **kwargs
514+
)
515+
model = STDiT2(config)
516+
load_checkpoint(model, from_pretrained)
517+
return model
518+
else:
519+
# otherwise, we load the model from the Hugging Face Hub
520+
return STDiT2.from_pretrained(from_pretrained)
506521
else:
522+
# create a new model
507523
config = STDiT2Config(
508524
depth=28,
509525
hidden_size=1152,

scripts/train.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -133,8 +133,7 @@ def main():
133133
input_size=latent_size,
134134
in_channels=vae.out_channels,
135135
caption_channels=text_encoder.output_dim,
136-
model_max_length=text_encoder.model_max_length,
137-
dtype=dtype,
136+
model_max_length=text_encoder.model_max_length
138137
)
139138
model_numel, model_numel_trainable = get_model_numel(model)
140139
logger.info(

0 commit comments

Comments
 (0)