File tree: 2 files changed, +18 -3 lines changed
Original file line number | Diff line number | Diff line change
1
1
import numpy as np
2
2
import torch
3
3
import torch .nn as nn
4
+ import os
4
5
from einops import rearrange
5
6
from rotary_embedding_torch import RotaryEmbedding
6
7
from timm .models .layers import DropPath
23
24
)
24
25
from opensora .registry import MODELS
25
26
from transformers import PretrainedConfig , PreTrainedModel
27
+ from opensora .utils .ckpt_utils import load_checkpoint
26
28
27
29
28
30
class STDiT2Block (nn .Module ):
@@ -502,8 +504,22 @@ def _basic_init(module):
502
504
@MODELS .register_module ("STDiT2-XL/2" )
503
505
def STDiT2_XL_2 (from_pretrained = None , ** kwargs ):
504
506
if from_pretrained is not None :
505
- model = STDiT2 .from_pretrained (from_pretrained , ** kwargs )
507
+ if os .path .isdir (from_pretrained ) or os .path .isfile (from_pretrained ):
508
+ # if it is a directory or a file, we load the checkpoint manually
509
+ config = STDiT2Config (
510
+ depth = 28 ,
511
+ hidden_size = 1152 ,
512
+ patch_size = (1 , 2 , 2 ),
513
+ num_heads = 16 , ** kwargs
514
+ )
515
+ model = STDiT2 (config )
516
+ load_checkpoint (model , from_pretrained )
517
+ return model
518
+ else :
519
+ # otherwise, we load the model from hugging face hub
520
+ return STDiT2 .from_pretrained (from_pretrained )
506
521
else :
522
+ # create a new model
507
523
config = STDiT2Config (
508
524
depth = 28 ,
509
525
hidden_size = 1152 ,
Original file line number Diff line number Diff line change @@ -133,8 +133,7 @@ def main():
133
133
input_size = latent_size ,
134
134
in_channels = vae .out_channels ,
135
135
caption_channels = text_encoder .output_dim ,
136
- model_max_length = text_encoder .model_max_length ,
137
- dtype = dtype ,
136
+ model_max_length = text_encoder .model_max_length
138
137
)
139
138
model_numel , model_numel_trainable = get_model_numel (model )
140
139
logger .info (
You can’t perform that action at this time.
0 commit comments