diff --git a/packages/tasks/src/model-libraries-snippets.ts b/packages/tasks/src/model-libraries-snippets.ts index 11e410598..6d1cba8d1 100644 --- a/packages/tasks/src/model-libraries-snippets.ts +++ b/packages/tasks/src/model-libraries-snippets.ts @@ -1535,6 +1535,23 @@ image = sana( ) `, ]; +export const videoprism = (model: ModelData): string[] => [ + `# Install from https://github.com/google-deepmind/videoprism +import jax +import jax.numpy as jnp +from videoprism import models as vp + +# Models available: ['videoprism_public_v1_base_hf', 'videoprism_public_v1_large_hf'] +MODEL_NAME = 'videoprism_public_v1_base_hf' + +flax_model = vp.MODELS[MODEL_NAME]() +loaded_state = vp.load_pretrained_weights(MODEL_NAME) + +@jax.jit +def forward_fn(inputs, train=False): + return flax_model.apply(loaded_state, inputs, train=train)`, +]; + export const vfimamba = (model: ModelData): string[] => [ `from Trainer_finetune import Model diff --git a/packages/tasks/src/model-libraries.ts b/packages/tasks/src/model-libraries.ts index 9641ed689..c4342a244 100644 --- a/packages/tasks/src/model-libraries.ts +++ b/packages/tasks/src/model-libraries.ts @@ -1086,6 +1086,13 @@ export const MODEL_LIBRARIES_UI_ELEMENTS = { countDownloads: `path_extension:"pth"`, snippets: snippets.sana, }, + "videoprism": { + prettyLabel: "VideoPrism", + repoName: "VideoPrism", + repoUrl: "https://github.com/google-deepmind/videoprism", + countDownloads: `path_extension:"npz"`, + snippets: snippets.videoprism, + }, "vfi-mamba": { prettyLabel: "VFIMamba", repoName: "VFIMamba",