diff --git a/packages/tasks/src/model-libraries-snippets.ts b/packages/tasks/src/model-libraries-snippets.ts index 141c1d2f4..ba5e6a87d 100644 --- a/packages/tasks/src/model-libraries-snippets.ts +++ b/packages/tasks/src/model-libraries-snippets.ts @@ -1542,6 +1542,20 @@ image = sana( ) `, ]; +export const videoprism = (model: ModelData): string[] => [ + `# Install from https://github.com/google-deepmind/videoprism +import jax +import jax.numpy as jnp +from videoprism import models as vp + +flax_model = vp.MODELS["${model.id}"]() +loaded_state = vp.load_pretrained_weights("${model.id}") + +@jax.jit +def forward_fn(inputs, train=False): + return flax_model.apply(loaded_state, inputs, train=train)`, +]; + export const vfimamba = (model: ModelData): string[] => [ `from Trainer_finetune import Model diff --git a/packages/tasks/src/model-libraries.ts b/packages/tasks/src/model-libraries.ts index 05f66b1a8..7662413b2 100644 --- a/packages/tasks/src/model-libraries.ts +++ b/packages/tasks/src/model-libraries.ts @@ -1093,6 +1093,13 @@ export const MODEL_LIBRARIES_UI_ELEMENTS = { countDownloads: `path_extension:"pth"`, snippets: snippets.sana, }, + videoprism: { + prettyLabel: "VideoPrism", + repoName: "VideoPrism", + repoUrl: "https://github.com/google-deepmind/videoprism", + countDownloads: `path_extension:"npz"`, + snippets: snippets.videoprism, + }, "vfi-mamba": { prettyLabel: "VFIMamba", repoName: "VFIMamba",