Skip to content

Changed Mem0 Storage v1.1 -> v2 #2893

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 15 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions docs/concepts/memory.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -720,7 +720,16 @@ crew = Crew(
```

### Advanced Mem0 Configuration
When using Mem0 Client, you can customize the memory configuration further, by using parameters like 'includes', 'excludes', 'custom_categories' and 'run_id' (this is only for short-term memory).
You can find more details in the [Mem0 documentation](https://docs.mem0.ai/).
```python

new_categories = [
{"lifestyle_management_concerns": "Tracks daily routines, habits, hobbies and interests including cooking, time management and work-life balance"},
{"seeking_structure": "Documents goals around creating routines, schedules, and organized systems in various life areas"},
{"personal_information": "Basic information about the user including name, preferences, and personality traits"}
]

crew = Crew(
agents=[...],
tasks=[...],
Expand All @@ -732,6 +741,10 @@ crew = Crew(
"org_id": "my_org_id", # Optional
"project_id": "my_project_id", # Optional
"api_key": "custom-api-key" # Optional - overrides env var
"run_id": "my_run_id", # Optional - for short-term memory
"includes": "include1", # Optional
"excludes": "exclude1", # Optional
"custom_categories": new_categories # Optional - custom categories for user memory
},
"user_memory": {}
}
Expand Down
192 changes: 102 additions & 90 deletions src/crewai/memory/storage/mem0_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,47 +13,67 @@ class Mem0Storage(Storage):
"""
Extends Storage to handle embedding and searching across entities using Mem0.
"""

def __init__(self, type, crew=None, config=None):
super().__init__()
supported_types = ["user", "short_term", "long_term", "entities", "external"]
if type not in supported_types:
raise ValueError(
f"Invalid type '{type}' for Mem0Storage. Must be one of: "
+ ", ".join(supported_types)
)

self._validate_type(type)
self.memory_type = type
self.crew = crew
self.config = config or {}

# TODO: Memory config will be removed in the future the config will be passed as a parameter
self.memory_config = self.config or getattr(crew, "memory_config", {}) or {}
self.config = config or getattr(crew, "memory_config", {}).get("config", {}) or {}

self._validate_user_id()
self._extract_config_values()
self._initialize_memory()

# User ID is required for user memory type "user" since it's used as a unique identifier for the user.
user_id = self._get_user_id()
if type == "user" and not user_id:
def _validate_type(self, type):
supported_types = {"user", "short_term", "long_term", "entities", "external"}
if type not in supported_types:
raise ValueError(
f"Invalid type '{type}' for Mem0Storage. Must be one of: {', '.join(supported_types)}"
)

def _validate_user_id(self):
if self.memory_type == "user" and not self.config.get("user_id", ""):
raise ValueError("User ID is required for user memory type")

# API key in memory config overrides the environment variable
config = self._get_config()
mem0_api_key = config.get("api_key") or os.getenv("MEM0_API_KEY")
mem0_org_id = config.get("org_id")
mem0_project_id = config.get("project_id")
mem0_local_config = config.get("local_mem0_config")

# Initialize MemoryClient or Memory based on the presence of the mem0_api_key
if mem0_api_key:
if mem0_org_id and mem0_project_id:
self.memory = MemoryClient(
api_key=mem0_api_key, org_id=mem0_org_id, project_id=mem0_project_id
)
else:
self.memory = MemoryClient(api_key=mem0_api_key)
def _extract_config_values(self):
cfg = self.config
self.mem0_run_id = cfg.get("run_id")
self.includes = cfg.get("includes")
self.excludes = cfg.get("excludes")
self.custom_categories = cfg.get("custom_categories")

def _initialize_memory(self):
api_key = self.config.get("api_key") or os.getenv("MEM0_API_KEY")
org_id = self.config.get("org_id")
project_id = self.config.get("project_id")
local_config = self.config.get("local_mem0_config")

if api_key:
self.memory = (
MemoryClient(api_key=api_key, org_id=org_id, project_id=project_id)
if org_id and project_id
else MemoryClient(api_key=api_key)
)
if self.custom_categories:
self.memory.update_project(custom_categories=self.custom_categories)
else:
if mem0_local_config and len(mem0_local_config):
self.memory = Memory.from_config(mem0_local_config)
else:
self.memory = Memory()
self.memory = (
Memory.from_config(local_config)
if local_config and len(local_config)
else Memory()
)

def _get_agent_name(self) -> str:
if not self.crew:
return ""

agents = self.crew.agents
agents = [self._sanitize_role(agent.role) for agent in agents]
agents = "_".join(agents)
return sanitize_collection_name(name=agents,max_collection_length=MAX_AGENT_ID_LENGTH_MEM0)

def _sanitize_role(self, role: str) -> str:
"""
Expand All @@ -62,86 +82,78 @@ def _sanitize_role(self, role: str) -> str:
return role.replace("\n", "").replace(" ", "_").replace("/", "_")

def save(self, value: Any, metadata: Dict[str, Any]) -> None:
user_id = self._get_user_id()
user_id = self.config.get("user_id", "")
agent_name = self._get_agent_name()
params = None

base_metadata = {
"short_term": "short_term",
"long_term": "long_term",
"entities": "entity",
"external": "external"
}

# Shared base params
params: dict[str, Any] = {
"agent_id": agent_name,
"metadata": {"type": base_metadata[self.memory_type], **metadata}
}

# Type-specific overrides
if self.memory_type == "short_term":
params = {
"agent_id": agent_name,
"infer": False,
"metadata": {"type": "short_term", **metadata},
}
params["infer"] = False
elif self.memory_type == "long_term":
params = {
"agent_id": agent_name,
"infer": False,
"metadata": {"type": "long_term", **metadata},
}
params["infer"] = False
elif self.memory_type == "entities":
params = {
"agent_id": agent_name,
"infer": False,
"metadata": {"type": "entity", **metadata},
}
params["infer"] = False
elif self.memory_type == "external":
params = {
"user_id": user_id,
"agent_id": agent_name,
"metadata": {"type": "external", **metadata},
}
params["user_id"] = user_id


if params:
# MemoryClient-specific overrides
if isinstance(self.memory, MemoryClient):
params["output_format"] = "v1.1"
params["version"] = "v2"
params["includes"] = self.includes
params["excludes"] = self.excludes

if self.memory_type == "short_term":
params["run_id"] = self.mem0_run_id

self.memory.add(value, **params)

def search(
self,
query: str,
limit: int = 3,
score_threshold: float = 0.35,
) -> List[Any]:
params = {"query": query, "limit": limit, "output_format": "v1.1"}
if user_id := self._get_user_id():
def search(self,query: str,limit: int = 3,score_threshold: float = 0.35) -> List[Any]:
params = {
"query": query,
"limit": limit,
"version": "v2"
}

if user_id := self.config.get("user_id", ""):
params["user_id"] = user_id

agent_name = self._get_agent_name()
if self.memory_type == "short_term":
params["agent_id"] = agent_name
params["metadata"] = {"type": "short_term"}
elif self.memory_type == "long_term":
params["agent_id"] = agent_name
params["metadata"] = {"type": "long_term"}
elif self.memory_type == "entities":
params["agent_id"] = agent_name
params["metadata"] = {"type": "entity"}
elif self.memory_type == "external":
params["agent_id"] = agent_name
params["metadata"] = {"type": "external"}
params["agent_id"] = agent_name

memory_type_map = {
"short_term": {"type": "short_term"},
"long_term": {"type": "long_term"},
"entities": {"type": "entity"},
"external": {"type": "external"},
}

if self.memory_type in memory_type_map:
params["metadata"] = memory_type_map[self.memory_type]
if self.memory_type == "short_term":
params["run_id"] = self.mem0_run_id

# Discard the filters for now since we create the filters
# automatically when the crew is created.
if isinstance(self.memory, Memory):
del params["metadata"], params["output_format"]
del params["metadata"], params["version"], params["run_id"]

results = self.memory.search(**params)
return [r for r in results["results"] if r["score"] >= score_threshold]

def _get_user_id(self) -> str:
return self._get_config().get("user_id", "")

def _get_agent_name(self) -> str:
if not self.crew:
return ""

agents = self.crew.agents
agents = [self._sanitize_role(agent.role) for agent in agents]
agents = "_".join(agents)
return sanitize_collection_name(name=agents,max_collection_length=MAX_AGENT_ID_LENGTH_MEM0)

def _get_config(self) -> Dict[str, Any]:
return self.config or getattr(self, "memory_config", {}).get("config", {}) or {}

def reset(self):
if self.memory:
self.memory.reset()
50 changes: 40 additions & 10 deletions tests/storage/test_mem0_storage.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,10 @@
import os
from unittest.mock import MagicMock, patch

import pytest
from mem0.client.main import MemoryClient
from mem0.memory.main import Memory

from crewai.agent import Agent
from crewai.crew import Crew
from crewai.memory.storage.mem0_storage import Mem0Storage
from crewai.task import Task


# Define the class (if not already defined)
Expand Down Expand Up @@ -59,10 +55,11 @@ def mem0_storage_with_mocked_config(mock_mem0_memory):
}

# Instantiate the class with memory_config
# Parameters like run_id, includes, and excludes doesn't matter in Memory OSS
crew = MockCrew(
memory_config={
"provider": "mem0",
"config": {"user_id": "test_user", "local_mem0_config": config},
"config": {"user_id": "test_user", "local_mem0_config": config, "run_id": "my_run_id", "includes": "include1","excludes": "exclude1"},
}
)

Expand Down Expand Up @@ -99,6 +96,9 @@ def mem0_storage_with_memory_client_using_config_from_crew(mock_mem0_memory_clie
"api_key": "ABCDEFGH",
"org_id": "my_org_id",
"project_id": "my_project_id",
"run_id": "my_run_id",
"includes": "include1",
"excludes": "exclude1",
},
}
)
Expand Down Expand Up @@ -154,11 +154,37 @@ def test_mem0_storage_with_explict_config(
assert (
mem0_storage_with_memory_client_using_explictly_config.config == expected_config
)
assert (
mem0_storage_with_memory_client_using_explictly_config.memory_config
== expected_config


def test_mem0_storage_updates_project_with_custom_categories(mock_mem0_memory_client):
mock_mem0_memory_client.update_project = MagicMock()

new_categories = [
{"lifestyle_management_concerns": "Tracks daily routines, habits, hobbies and interests including cooking, time management and work-life balance"},
]

crew = MockCrew(
memory_config={
"provider": "mem0",
"config": {
"user_id": "test_user",
"api_key": "ABCDEFGH",
"org_id": "my_org_id",
"project_id": "my_project_id",
"custom_categories": new_categories,
},
}
)

with patch.object(MemoryClient, "__new__", return_value=mock_mem0_memory_client):
_ = Mem0Storage(type="short_term", crew=crew)

mock_mem0_memory_client.update_project.assert_called_once_with(
custom_categories=new_categories
)




def test_save_method_with_memory_oss(mem0_storage_with_mocked_config):
"""Test save method for different memory types"""
Expand Down Expand Up @@ -195,7 +221,10 @@ def test_save_method_with_memory_client(mem0_storage_with_memory_client_using_co
agent_id="Test_Agent",
infer=False,
metadata={"type": "short_term", "key": "value"},
output_format="v1.1"
version="v2",
run_id="my_run_id",
includes="include1",
excludes="exclude1",
)


Expand Down Expand Up @@ -232,7 +261,8 @@ def test_search_method_with_memory_client(mem0_storage_with_memory_client_using_
agent_id="Test_Agent",
metadata={"type": "short_term"},
user_id="test_user",
output_format='v1.1'
version='v2',
run_id="my_run_id",
)

assert len(results) == 1
Expand Down
Loading