Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions python/dify_plugin/core/entities/invocation.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@ class InvokeType(Enum):
LLM = "llm"
LLMStructuredOutput = "llm_structured_output"
TextEmbedding = "text_embedding"
MultimodalEmbedding = "multimodal_embedding"
Rerank = "rerank"
MultimodalRerank = "multimodal_rerank"
TTS = "tts"
Speech2Text = "speech2text"
Moderation = "moderation"
Expand Down
25 changes: 23 additions & 2 deletions python/dify_plugin/core/entities/plugin/request.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from collections.abc import Mapping
from collections.abc import Mapping, Sequence
from enum import StrEnum
from typing import Any

Expand All @@ -9,7 +9,7 @@
OnlineDriveBrowseFilesRequest,
OnlineDriveDownloadFileRequest,
)
from dify_plugin.entities.model import ModelType
from dify_plugin.entities.model import EmbeddingInputType, ModelType
from dify_plugin.entities.model.message import (
AssistantPromptMessage,
PromptMessage,
Expand All @@ -19,6 +19,7 @@
ToolPromptMessage,
UserPromptMessage,
)
from dify_plugin.entities.model.text_embedding import MultiModalContent
from dify_plugin.entities.provider_config import CredentialType
from dify_plugin.entities.trigger import Subscription

Expand Down Expand Up @@ -59,8 +60,10 @@ class ModelActions(StrEnum):
InvokeLLM = "invoke_llm"
GetLLMNumTokens = "get_llm_num_tokens"
InvokeTextEmbedding = "invoke_text_embedding"
InvokeMultimodalEmbedding = "invoke_multimodal_embedding"
GetTextEmbeddingNumTokens = "get_text_embedding_num_tokens"
InvokeRerank = "invoke_rerank"
InvokeMultimodalRerank = "invoke_multimodal_rerank"
InvokeTTS = "invoke_tts"
GetTTSVoices = "get_tts_model_voices"
InvokeSpeech2Text = "invoke_speech2text"
Expand Down Expand Up @@ -202,6 +205,14 @@ class ModelInvokeTextEmbeddingRequest(PluginAccessModelRequest):
texts: list[str]


# Request payload routed to the `invoke_multimodal_embedding` model action.
class ModelInvokeMultimodalEmbeddingRequest(PluginAccessModelRequest):
    # Action discriminator; defaults to the multimodal-embedding action value.
    action: ModelActions = ModelActions.InvokeMultimodalEmbedding
    # Multimodal embedding requests default to the TEXT_EMBEDDING model type.
    model_type: ModelType = ModelType.TEXT_EMBEDDING

    # Contents to embed; each item declares its own content type.
    documents: list[MultiModalContent]
    # Falls back to DOCUMENT when the caller does not supply an input type.
    input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT


class ModelGetTextEmbeddingNumTokens(PluginAccessModelRequest):
action: ModelActions = ModelActions.GetTextEmbeddingNumTokens

Expand All @@ -217,6 +228,16 @@ class ModelInvokeRerankRequest(PluginAccessModelRequest):
top_n: int | None


# Request payload routed to the `invoke_multimodal_rerank` model action.
class ModelInvokeMultimodalRerankRequest(PluginAccessModelRequest):
    # Action discriminator; defaults to the multimodal-rerank action value.
    action: ModelActions = ModelActions.InvokeMultimodalRerank
    # Multimodal rerank requests default to the RERANK model type.
    model_type: ModelType = ModelType.RERANK

    # The query content the documents are ranked against.
    query: MultiModalContent
    # Candidate contents to be reranked.
    docs: Sequence[MultiModalContent]
    # No default is declared for the two fields below: callers must send them,
    # but may send an explicit null.
    score_threshold: float | None
    top_n: int | None


class ModelInvokeTTSRequest(PluginAccessModelRequest):
action: ModelActions = ModelActions.InvokeTTS

Expand Down
28 changes: 28 additions & 0 deletions python/dify_plugin/core/plugin_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
ModelGetTTSVoices,
ModelInvokeLLMRequest,
ModelInvokeModerationRequest,
ModelInvokeMultimodalEmbeddingRequest,
ModelInvokeMultimodalRerankRequest,
ModelInvokeRerankRequest,
ModelInvokeSpeech2TextRequest,
ModelInvokeTextEmbeddingRequest,
Expand Down Expand Up @@ -219,6 +221,18 @@ def invoke_text_embedding(self, session: Session, data: ModelInvokeTextEmbedding
else:
raise ValueError(f"Model `{data.model_type}` not found for provider `{data.provider}`")

def invoke_multimodal_embedding(self, session: Session, data: ModelInvokeMultimodalEmbeddingRequest):
    """Dispatch a multimodal embedding request to the registered embedding model.

    Looks up the model instance for ``data.provider`` / ``data.model_type`` and
    forwards the request to its ``invoke_multimodal`` method.

    Raises:
        ValueError: if the resolved instance is not a ``TextEmbeddingModel``.
    """
    embedding_model = self.registration.get_model_instance(data.provider, data.model_type)

    # Guard first: anything that is not an embedding model cannot serve this call.
    if not isinstance(embedding_model, TextEmbeddingModel):
        raise ValueError(f"Model `{data.model_type}` not found for provider `{data.provider}`")

    return embedding_model.invoke_multimodal(
        data.model,
        data.credentials,
        data.documents,
        user=data.user_id,
        input_type=data.input_type,
    )

def get_text_embedding_num_tokens(self, session: Session, data: ModelGetTextEmbeddingNumTokens):
model_instance = self.registration.get_model_instance(data.provider, data.model_type)
if isinstance(model_instance, TextEmbeddingModel):
Expand Down Expand Up @@ -247,6 +261,20 @@ def invoke_rerank(self, session: Session, data: ModelInvokeRerankRequest):
else:
raise ValueError(f"Model `{data.model_type}` not found for provider `{data.provider}`")

def invoke_multimodal_rerank(self, session: Session, data: ModelInvokeMultimodalRerankRequest):
    """Dispatch a multimodal rerank request to the registered rerank model.

    Looks up the model instance for ``data.provider`` / ``data.model_type`` and
    forwards the query, documents and filtering options to its
    ``invoke_multimodal`` method.

    Raises:
        ValueError: if the resolved instance is not a ``RerankModel``.
    """
    rerank_model = self.registration.get_model_instance(data.provider, data.model_type)

    # Guard first: anything that is not a rerank model cannot serve this call.
    if not isinstance(rerank_model, RerankModel):
        raise ValueError(f"Model `{data.model_type}` not found for provider `{data.provider}`")

    return rerank_model.invoke_multimodal(
        data.model,
        data.credentials,
        data.query,
        data.docs,
        score_threshold=data.score_threshold,
        top_n=data.top_n,
        user=data.user_id,
    )

def invoke_tts(self, session: Session, data: ModelInvokeTTSRequest):
model_instance = self.registration.get_model_instance(data.provider, data.model_type)
if isinstance(model_instance, TTSModel):
Expand Down
25 changes: 24 additions & 1 deletion python/dify_plugin/entities/model/rerank.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from pydantic import BaseModel, ConfigDict
from pydantic import BaseModel, ConfigDict, Field

from dify_plugin.entities.model import BaseModelConfig, ModelType

Expand Down Expand Up @@ -32,3 +32,26 @@ class RerankModelConfig(BaseModelConfig):
top_n: int

model_config = ConfigDict(protected_namespaces=())


class MultiModalRerankResult(BaseModel):
    """Rerank response produced by a multimodal rerank model."""

    # Both fields are required: `Field(...)` declares no default value.
    model: str = Field(..., description="Identifier of the model producing the reranked documents.")
    # Reuses the `RerankDocument` entity for each scored entry.
    docs: list[RerankDocument] = Field(..., description="Reranked documents with scores.")


class MultiModalRerankModelConfig(BaseModelConfig):
    """Configuration payload for invoking a multimodal rerank model."""

    # Defaults to the RERANK model type.
    model_type: ModelType = ModelType.RERANK
    # Both tuning knobs below are optional here and default to None.
    score_threshold: float | None = Field(
        default=None,
        description="Optional threshold for filtering documents based on score.",
    )
    top_n: int | None = Field(
        default=None,
        description="Optional limit on the number of documents returned.",
    )

    # Empty protected namespaces so pydantic does not complain about the
    # `model_`-prefixed field name above.
    model_config = ConfigDict(protected_namespaces=())
34 changes: 33 additions & 1 deletion python/dify_plugin/entities/model/text_embedding.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from decimal import Decimal
from enum import StrEnum

from pydantic import BaseModel, ConfigDict
from pydantic import BaseModel, ConfigDict, Field

from dify_plugin.entities.model import BaseModelConfig, ModelType, ModelUsage

Expand Down Expand Up @@ -37,3 +38,34 @@ class TextEmbeddingModelConfig(BaseModelConfig):
model_type: ModelType = ModelType.TEXT_EMBEDDING

model_config = ConfigDict(protected_namespaces=())


class MultiModalContentType(StrEnum):
    """Supported content types for multimodal inputs."""

    # String-valued members: each member compares equal to its lowercase name.
    TEXT = "text"
    IMAGE = "image"


class MultiModalContent(BaseModel):
    """A multimodal content payload provided by the caller."""

    # Both fields are required: `Field(...)` declares no default value.
    # `content` is always a string; `content_type` tells the consumer how to read it.
    content: str = Field(..., description="The payload content, plain text or base64 encoded file data.")
    content_type: MultiModalContentType = Field(..., description="The modality of the provided content.")


class MultiModalEmbeddingResult(BaseModel):
    """Embedding response produced by a multimodal embedding model."""

    # All three fields are required: `Field(...)` declares no default value.
    model: str = Field(..., description="Identifier of the model generating embeddings.")
    # One float vector per embedded content — presumably in input order; TODO confirm.
    embeddings: list[list[float]] = Field(..., description="Embedding vectors for provided contents.")
    usage: EmbeddingUsage = Field(..., description="Usage metrics associated with the inference.")


class MultiModalEmbeddingModelConfig(BaseModelConfig):
    """Configuration payload for invoking a multimodal embedding model."""

    # Defaults to the TEXT_EMBEDDING model type.
    model_type: ModelType = ModelType.TEXT_EMBEDDING
    # Required: `Field(...)` declares no default value.
    tenant_id: str = Field(..., description="Vendor tenant identifier associated with the dataset.")

    # Empty protected namespaces so pydantic does not complain about the
    # `model_`-prefixed field name above.
    model_config = ConfigDict(protected_namespaces=())
49 changes: 48 additions & 1 deletion python/dify_plugin/interfaces/model/rerank_model.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from abc import abstractmethod
from collections.abc import Sequence

from dify_plugin.entities.model import ModelType
from dify_plugin.entities.model.rerank import RerankResult
from dify_plugin.entities.model.rerank import MultiModalRerankResult, RerankResult
from dify_plugin.entities.model.text_embedding import MultiModalContent
from dify_plugin.interfaces.model.ai_model import AIModel


Expand Down Expand Up @@ -41,6 +43,23 @@ def _invoke(
"""
raise NotImplementedError

def _invoke_multimodal(
    self,
    model: str,
    credentials: dict,
    query: MultiModalContent,
    docs: Sequence[MultiModalContent],
    score_threshold: float | None = None,
    top_n: int | None = None,
    user: str | None = None,
) -> MultiModalRerankResult:
    """Default hook for multimodal rerank; subclasses override it to add support.

    This base implementation never returns.

    Raises:
        NotImplementedError: always, naming the concrete class that lacks the override.
    """
    owner = self.__class__.__name__
    reason = f"{owner} does not implement `_invoke_multimodal`. "
    hint = "Implement this method to support multimodal rerank invocations."
    raise NotImplementedError(reason + hint)

############################################################
# For executor use only #
############################################################
Expand Down Expand Up @@ -73,3 +92,31 @@ def invoke(
return self._invoke(model, credentials, query, docs, score_threshold, top_n, user)
except Exception as e:
raise self._transform_invoke_error(e) from e

def invoke_multimodal(
    self,
    model: str,
    credentials: dict,
    query: MultiModalContent,
    docs: Sequence[MultiModalContent],
    score_threshold: float | None = None,
    top_n: int | None = None,
    user: str | None = None,
) -> MultiModalRerankResult:
    """Run a multimodal rerank through ``_invoke_multimodal`` inside the timing context.

    Returns:
        Whatever ``_invoke_multimodal`` returns.

    Raises:
        NotImplementedError: propagated untouched when the hook is not implemented.
        Exception: any other failure is re-raised as the result of
            ``_transform_invoke_error``, chained to the original error.
    """
    with self.timing_context():
        try:
            outcome = self._invoke_multimodal(model, credentials, query, docs, score_threshold, top_n, user)
        except Exception as exc:
            # A missing implementation must surface as-is, not as a transformed error.
            if isinstance(exc, NotImplementedError):
                raise
            raise self._transform_invoke_error(exc) from exc
        return outcome
45 changes: 44 additions & 1 deletion python/dify_plugin/interfaces/model/text_embedding_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,11 @@
from pydantic import ConfigDict

from dify_plugin.entities.model import EmbeddingInputType, ModelPropertyKey, ModelType
from dify_plugin.entities.model.text_embedding import TextEmbeddingResult
from dify_plugin.entities.model.text_embedding import (
MultiModalContent,
MultiModalEmbeddingResult,
TextEmbeddingResult,
)
from dify_plugin.interfaces.model.ai_model import AIModel


Expand Down Expand Up @@ -42,6 +46,21 @@ def _invoke(
"""
raise NotImplementedError

def _invoke_multimodal(
    self,
    model: str,
    credentials: dict,
    documents: list[MultiModalContent],
    user: str | None = None,
    input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
) -> MultiModalEmbeddingResult:
    """Default hook for multimodal embeddings; subclasses override it to add support.

    This base implementation never returns.

    Raises:
        NotImplementedError: always, naming the concrete class that lacks the override.
    """
    owner = self.__class__.__name__
    reason = f"{owner} does not implement `_invoke_multimodal`. "
    hint = "Implement this method to support multimodal embeddings."
    raise NotImplementedError(reason + hint)

@abstractmethod
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> list[int]:
"""
Expand Down Expand Up @@ -115,3 +134,27 @@ def invoke(
return self._invoke(model, credentials, texts, user, input_type)
except Exception as e:
raise self._transform_invoke_error(e) from e

def invoke_multimodal(
    self,
    model: str,
    credentials: dict,
    documents: list[MultiModalContent],
    user: str | None = None,
    input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
) -> MultiModalEmbeddingResult:
    """Run a multimodal embedding through ``_invoke_multimodal`` inside the timing context.

    Returns:
        Whatever ``_invoke_multimodal`` returns.

    Raises:
        NotImplementedError: propagated untouched when the hook is not implemented.
        Exception: any other failure is re-raised as the result of
            ``_transform_invoke_error``, chained to the original error.
    """
    with self.timing_context():
        try:
            outcome = self._invoke_multimodal(model, credentials, documents, user, input_type)
        except Exception as exc:
            # A missing implementation must surface as-is, not as a transformed error.
            if isinstance(exc, NotImplementedError):
                raise
            raise self._transform_invoke_error(exc) from exc
        return outcome
29 changes: 28 additions & 1 deletion python/dify_plugin/invocations/model/rerank.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
from dify_plugin.core.entities.invocation import InvokeType
from dify_plugin.core.runtime import BackwardsInvocation
from dify_plugin.entities.model.rerank import RerankModelConfig, RerankResult
from dify_plugin.entities.model.rerank import (
MultiModalRerankModelConfig,
MultiModalRerankResult,
RerankModelConfig,
RerankResult,
)
from dify_plugin.entities.model.text_embedding import MultiModalContent


class RerankInvocation(BackwardsInvocation[RerankResult]):
Expand All @@ -20,3 +26,24 @@ def invoke(self, model_config: RerankModelConfig, docs: list[str], query: str) -
return data

raise Exception("No response from rerank")

def invoke_multimodal(
self,
model_config: MultiModalRerankModelConfig,
query: MultiModalContent,
docs: list[MultiModalContent],
) -> MultiModalRerankResult:
payload = {
**model_config.model_dump(),
"query": query.model_dump() if isinstance(query, MultiModalContent) else query,
"docs": [doc.model_dump() if isinstance(doc, MultiModalContent) else doc for doc in docs],
Comment on lines +38 to +39
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The type hints for query and docs are MultiModalContent and list[MultiModalContent] respectively. Pydantic ensures that these will be instances of MultiModalContent. Therefore, the isinstance checks are redundant and can be removed to simplify the code.

Suggested change
"query": query.model_dump() if isinstance(query, MultiModalContent) else query,
"docs": [doc.model_dump() if isinstance(doc, MultiModalContent) else doc for doc in docs],
"query": query.model_dump(),
"docs": [doc.model_dump() for doc in docs],

}

for data in self._backwards_invoke(
InvokeType.MultimodalRerank,
MultiModalRerankResult,
payload,
):
return data

raise Exception("No response from multimodal rerank")
Loading