Skip to content

Commit d60ad8b

Browse files
authored
feat: modify mem-reader prompt (#273)
* feat: timeout for nebula query 5s->10s * feat: exclude heavy feilds when calling memories from nebula db * test: fix tree-text-mem searcher text * feat: adjust prompt * feat: adjust prompt
1 parent 25f7a5a commit d60ad8b

File tree

3 files changed

+339
-129
lines changed

3 files changed

+339
-129
lines changed

src/memos/graph_dbs/nebular.py

Lines changed: 148 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
from contextlib import suppress
55
from datetime import datetime
6+
from queue import Empty, Queue
67
from threading import Lock
78
from typing import TYPE_CHECKING, Any, ClassVar, Literal
89

@@ -86,6 +87,137 @@ def _normalize_datetime(val):
8687
return str(val)
8788

8889

90+
class SessionPoolError(Exception):
91+
pass
92+
93+
94+
class SessionPool:
95+
@require_python_package(
96+
import_name="nebulagraph_python",
97+
install_command="pip install ... @Tianxing",
98+
install_link=".....",
99+
)
100+
def __init__(
101+
self,
102+
hosts: list[str],
103+
user: str,
104+
password: str,
105+
minsize: int = 1,
106+
maxsize: int = 10000,
107+
):
108+
self.hosts = hosts
109+
self.user = user
110+
self.password = password
111+
self.minsize = minsize
112+
self.maxsize = maxsize
113+
self.pool = Queue(maxsize)
114+
self.lock = Lock()
115+
116+
self.clients = []
117+
118+
for _ in range(minsize):
119+
self._create_and_add_client()
120+
121+
@timed
122+
def _create_and_add_client(self):
123+
from nebulagraph_python import NebulaClient
124+
125+
client = NebulaClient(self.hosts, self.user, self.password)
126+
self.pool.put(client)
127+
self.clients.append(client)
128+
129+
@timed
130+
def get_client(self, timeout: float = 5.0):
131+
try:
132+
return self.pool.get(timeout=timeout)
133+
except Empty:
134+
with self.lock:
135+
if len(self.clients) < self.maxsize:
136+
from nebulagraph_python import NebulaClient
137+
138+
client = NebulaClient(self.hosts, self.user, self.password)
139+
self.clients.append(client)
140+
return client
141+
raise RuntimeError("NebulaClientPool exhausted") from None
142+
143+
@timed
144+
def return_client(self, client):
145+
try:
146+
client.execute("YIELD 1")
147+
self.pool.put(client)
148+
except Exception:
149+
logger.info("[Pool] Client dead, replacing...")
150+
self.replace_client(client)
151+
152+
@timed
153+
def close(self):
154+
for client in self.clients:
155+
with suppress(Exception):
156+
client.close()
157+
self.clients.clear()
158+
159+
@timed
160+
def get(self):
161+
"""
162+
Context manager: with pool.get() as client:
163+
"""
164+
165+
class _ClientContext:
166+
def __init__(self, outer):
167+
self.outer = outer
168+
self.client = None
169+
170+
def __enter__(self):
171+
self.client = self.outer.get_client()
172+
return self.client
173+
174+
def __exit__(self, exc_type, exc_val, exc_tb):
175+
if self.client:
176+
self.outer.return_client(self.client)
177+
178+
return _ClientContext(self)
179+
180+
@timed
181+
def reset_pool(self):
182+
"""⚠️ Emergency reset: Close all clients and clear the pool."""
183+
logger.warning("[Pool] Resetting all clients. Existing sessions will be lost.")
184+
with self.lock:
185+
for client in self.clients:
186+
try:
187+
client.close()
188+
except Exception:
189+
logger.error("Fail to close!!!")
190+
self.clients.clear()
191+
while not self.pool.empty():
192+
try:
193+
self.pool.get_nowait()
194+
except Empty:
195+
break
196+
for _ in range(self.minsize):
197+
self._create_and_add_client()
198+
logger.info("[Pool] Pool has been reset successfully.")
199+
200+
@timed
201+
def replace_client(self, client):
202+
try:
203+
client.close()
204+
except Exception:
205+
logger.error("Fail to close client")
206+
207+
if client in self.clients:
208+
self.clients.remove(client)
209+
210+
from nebulagraph_python import NebulaClient
211+
212+
new_client = NebulaClient(self.hosts, self.user, self.password)
213+
self.clients.append(new_client)
214+
215+
self.pool.put(new_client)
216+
217+
logger.info("[Pool] Replaced dead client with a new one.")
218+
return new_client
219+
220+
89221
class NebulaGraphDB(BaseGraphDB):
90222
"""
91223
NebulaGraph-based implementation of a graph memory store.
@@ -125,19 +257,18 @@ def _get_or_create_shared_pool(cls, cfg: NebulaGraphDBConfig):
125257
Get a shared NebulaPool from cache or create one if missing.
126258
Thread-safe with a lock; maintains a simple refcount.
127259
"""
128-
from nebulagraph_python.client.pool import NebulaPool, NebulaPoolConfig
129-
130260
key = cls._make_pool_key(cfg)
131261

132262
with cls._POOL_LOCK:
133263
pool = cls._POOL_CACHE.get(key)
134264
if pool is None:
135265
# Create a new pool and put into cache
136-
pool = NebulaPool(
266+
pool = SessionPool(
137267
hosts=cfg.get("uri"),
138-
username=cfg.get("user"),
268+
user=cfg.get("user"),
139269
password=cfg.get("password"),
140-
pool_config=NebulaPoolConfig(max_client_size=cfg.get("max_client", 1000)),
270+
minsize=1,
271+
maxsize=cfg.get("max_client", 1000),
141272
)
142273
cls._POOL_CACHE[key] = pool
143274
cls._POOL_REFCOUNT[key] = 0
@@ -256,17 +387,18 @@ def __init__(self, config: NebulaGraphDBConfig):
256387

257388
@timed
258389
def execute_query(self, gql: str, timeout: float = 10.0, auto_set_db: bool = True):
259-
needs_use_prefix = ("SESSION SET GRAPH" not in gql) and ("USE " not in gql)
260-
use_prefix = f"USE `{self.db_name}` " if auto_set_db and needs_use_prefix else ""
261-
262-
ngql = use_prefix + gql
390+
with self.pool.get() as client:
391+
try:
392+
if auto_set_db and self.db_name:
393+
client.execute(f"SESSION SET GRAPH `{self.db_name}`")
394+
return client.execute(gql, timeout=timeout)
263395

264-
try:
265-
with self.pool.borrow() as client:
266-
return client.execute(ngql, timeout=timeout)
267-
except Exception as e:
268-
logger.error(f"[execute_query] Failed: {e}")
269-
raise
396+
except Exception as e:
397+
if "Session not found" in str(e) or "Connection not established" in str(e):
398+
logger.warning(f"[execute_query] {e!s}, replacing client...")
399+
self.pool.replace_client(client)
400+
return self.execute_query(gql, timeout, auto_set_db)
401+
raise
270402

271403
@timed
272404
def close(self):
@@ -940,20 +1072,12 @@ def get_by_metadata(self, filters: list[dict[str, Any]]) -> list[str]:
9401072
"""
9411073
where_clauses = []
9421074

943-
def _escape_value(value):
944-
if isinstance(value, str):
945-
return f'"{value}"'
946-
elif isinstance(value, list):
947-
return "[" + ", ".join(_escape_value(v) for v in value) + "]"
948-
else:
949-
return str(value)
950-
9511075
for _i, f in enumerate(filters):
9521076
field = f["field"]
9531077
op = f.get("op", "=")
9541078
value = f["value"]
9551079

956-
escaped_value = _escape_value(value)
1080+
escaped_value = self._format_value(value)
9571081

9581082
# Build WHERE clause
9591083
if op == "=":

0 commit comments

Comments
 (0)