|
3 | 3 |
|
4 | 4 | from contextlib import suppress
|
5 | 5 | from datetime import datetime
|
| 6 | +from queue import Empty, Queue |
6 | 7 | from threading import Lock
|
7 | 8 | from typing import TYPE_CHECKING, Any, ClassVar, Literal
|
8 | 9 |
|
@@ -86,6 +87,137 @@ def _normalize_datetime(val):
|
86 | 87 | return str(val)
|
87 | 88 |
|
88 | 89 |
|
| 90 | +class SessionPoolError(Exception): |
| 91 | + pass |
| 92 | + |
| 93 | + |
| 94 | +class SessionPool: |
| 95 | + @require_python_package( |
| 96 | + import_name="nebulagraph_python", |
| 97 | + install_command="pip install ... @Tianxing", |
| 98 | + install_link=".....", |
| 99 | + ) |
| 100 | + def __init__( |
| 101 | + self, |
| 102 | + hosts: list[str], |
| 103 | + user: str, |
| 104 | + password: str, |
| 105 | + minsize: int = 1, |
| 106 | + maxsize: int = 10000, |
| 107 | + ): |
| 108 | + self.hosts = hosts |
| 109 | + self.user = user |
| 110 | + self.password = password |
| 111 | + self.minsize = minsize |
| 112 | + self.maxsize = maxsize |
| 113 | + self.pool = Queue(maxsize) |
| 114 | + self.lock = Lock() |
| 115 | + |
| 116 | + self.clients = [] |
| 117 | + |
| 118 | + for _ in range(minsize): |
| 119 | + self._create_and_add_client() |
| 120 | + |
| 121 | + @timed |
| 122 | + def _create_and_add_client(self): |
| 123 | + from nebulagraph_python import NebulaClient |
| 124 | + |
| 125 | + client = NebulaClient(self.hosts, self.user, self.password) |
| 126 | + self.pool.put(client) |
| 127 | + self.clients.append(client) |
| 128 | + |
| 129 | + @timed |
| 130 | + def get_client(self, timeout: float = 5.0): |
| 131 | + try: |
| 132 | + return self.pool.get(timeout=timeout) |
| 133 | + except Empty: |
| 134 | + with self.lock: |
| 135 | + if len(self.clients) < self.maxsize: |
| 136 | + from nebulagraph_python import NebulaClient |
| 137 | + |
| 138 | + client = NebulaClient(self.hosts, self.user, self.password) |
| 139 | + self.clients.append(client) |
| 140 | + return client |
| 141 | + raise RuntimeError("NebulaClientPool exhausted") from None |
| 142 | + |
| 143 | + @timed |
| 144 | + def return_client(self, client): |
| 145 | + try: |
| 146 | + client.execute("YIELD 1") |
| 147 | + self.pool.put(client) |
| 148 | + except Exception: |
| 149 | + logger.info("[Pool] Client dead, replacing...") |
| 150 | + self.replace_client(client) |
| 151 | + |
| 152 | + @timed |
| 153 | + def close(self): |
| 154 | + for client in self.clients: |
| 155 | + with suppress(Exception): |
| 156 | + client.close() |
| 157 | + self.clients.clear() |
| 158 | + |
| 159 | + @timed |
| 160 | + def get(self): |
| 161 | + """ |
| 162 | + Context manager: with pool.get() as client: |
| 163 | + """ |
| 164 | + |
| 165 | + class _ClientContext: |
| 166 | + def __init__(self, outer): |
| 167 | + self.outer = outer |
| 168 | + self.client = None |
| 169 | + |
| 170 | + def __enter__(self): |
| 171 | + self.client = self.outer.get_client() |
| 172 | + return self.client |
| 173 | + |
| 174 | + def __exit__(self, exc_type, exc_val, exc_tb): |
| 175 | + if self.client: |
| 176 | + self.outer.return_client(self.client) |
| 177 | + |
| 178 | + return _ClientContext(self) |
| 179 | + |
| 180 | + @timed |
| 181 | + def reset_pool(self): |
| 182 | + """⚠️ Emergency reset: Close all clients and clear the pool.""" |
| 183 | + logger.warning("[Pool] Resetting all clients. Existing sessions will be lost.") |
| 184 | + with self.lock: |
| 185 | + for client in self.clients: |
| 186 | + try: |
| 187 | + client.close() |
| 188 | + except Exception: |
| 189 | + logger.error("Fail to close!!!") |
| 190 | + self.clients.clear() |
| 191 | + while not self.pool.empty(): |
| 192 | + try: |
| 193 | + self.pool.get_nowait() |
| 194 | + except Empty: |
| 195 | + break |
| 196 | + for _ in range(self.minsize): |
| 197 | + self._create_and_add_client() |
| 198 | + logger.info("[Pool] Pool has been reset successfully.") |
| 199 | + |
| 200 | + @timed |
| 201 | + def replace_client(self, client): |
| 202 | + try: |
| 203 | + client.close() |
| 204 | + except Exception: |
| 205 | + logger.error("Fail to close client") |
| 206 | + |
| 207 | + if client in self.clients: |
| 208 | + self.clients.remove(client) |
| 209 | + |
| 210 | + from nebulagraph_python import NebulaClient |
| 211 | + |
| 212 | + new_client = NebulaClient(self.hosts, self.user, self.password) |
| 213 | + self.clients.append(new_client) |
| 214 | + |
| 215 | + self.pool.put(new_client) |
| 216 | + |
| 217 | + logger.info("[Pool] Replaced dead client with a new one.") |
| 218 | + return new_client |
| 219 | + |
| 220 | + |
89 | 221 | class NebulaGraphDB(BaseGraphDB):
|
90 | 222 | """
|
91 | 223 | NebulaGraph-based implementation of a graph memory store.
|
@@ -125,19 +257,18 @@ def _get_or_create_shared_pool(cls, cfg: NebulaGraphDBConfig):
|
125 | 257 | Get a shared NebulaPool from cache or create one if missing.
|
126 | 258 | Thread-safe with a lock; maintains a simple refcount.
|
127 | 259 | """
|
128 |
| - from nebulagraph_python.client.pool import NebulaPool, NebulaPoolConfig |
129 |
| - |
130 | 260 | key = cls._make_pool_key(cfg)
|
131 | 261 |
|
132 | 262 | with cls._POOL_LOCK:
|
133 | 263 | pool = cls._POOL_CACHE.get(key)
|
134 | 264 | if pool is None:
|
135 | 265 | # Create a new pool and put into cache
|
136 |
| - pool = NebulaPool( |
| 266 | + pool = SessionPool( |
137 | 267 | hosts=cfg.get("uri"),
|
138 |
| - username=cfg.get("user"), |
| 268 | + user=cfg.get("user"), |
139 | 269 | password=cfg.get("password"),
|
140 |
| - pool_config=NebulaPoolConfig(max_client_size=cfg.get("max_client", 1000)), |
| 270 | + minsize=1, |
| 271 | + maxsize=cfg.get("max_client", 1000), |
141 | 272 | )
|
142 | 273 | cls._POOL_CACHE[key] = pool
|
143 | 274 | cls._POOL_REFCOUNT[key] = 0
|
@@ -256,17 +387,18 @@ def __init__(self, config: NebulaGraphDBConfig):
|
256 | 387 |
|
257 | 388 | @timed
|
258 | 389 | def execute_query(self, gql: str, timeout: float = 10.0, auto_set_db: bool = True):
|
259 |
| - needs_use_prefix = ("SESSION SET GRAPH" not in gql) and ("USE " not in gql) |
260 |
| - use_prefix = f"USE `{self.db_name}` " if auto_set_db and needs_use_prefix else "" |
261 |
| - |
262 |
| - ngql = use_prefix + gql |
| 390 | + with self.pool.get() as client: |
| 391 | + try: |
| 392 | + if auto_set_db and self.db_name: |
| 393 | + client.execute(f"SESSION SET GRAPH `{self.db_name}`") |
| 394 | + return client.execute(gql, timeout=timeout) |
263 | 395 |
|
264 |
| - try: |
265 |
| - with self.pool.borrow() as client: |
266 |
| - return client.execute(ngql, timeout=timeout) |
267 |
| - except Exception as e: |
268 |
| - logger.error(f"[execute_query] Failed: {e}") |
269 |
| - raise |
| 396 | + except Exception as e: |
| 397 | + if "Session not found" in str(e) or "Connection not established" in str(e): |
| 398 | + logger.warning(f"[execute_query] {e!s}, replacing client...") |
| 399 | + self.pool.replace_client(client) |
| 400 | + return self.execute_query(gql, timeout, auto_set_db) |
| 401 | + raise |
270 | 402 |
|
271 | 403 | @timed
|
272 | 404 | def close(self):
|
@@ -940,20 +1072,12 @@ def get_by_metadata(self, filters: list[dict[str, Any]]) -> list[str]:
|
940 | 1072 | """
|
941 | 1073 | where_clauses = []
|
942 | 1074 |
|
943 |
| - def _escape_value(value): |
944 |
| - if isinstance(value, str): |
945 |
| - return f'"{value}"' |
946 |
| - elif isinstance(value, list): |
947 |
| - return "[" + ", ".join(_escape_value(v) for v in value) + "]" |
948 |
| - else: |
949 |
| - return str(value) |
950 |
| - |
951 | 1075 | for _i, f in enumerate(filters):
|
952 | 1076 | field = f["field"]
|
953 | 1077 | op = f.get("op", "=")
|
954 | 1078 | value = f["value"]
|
955 | 1079 |
|
956 |
| - escaped_value = _escape_value(value) |
| 1080 | + escaped_value = self._format_value(value) |
957 | 1081 |
|
958 | 1082 | # Build WHERE clause
|
959 | 1083 | if op == "=":
|
|
0 commit comments