Skip to content

Commit 2f9d8c5

Browse files
committed
feat: refactor ChatPromptBuilder to make use of extract_declared_variables
1 parent 3db96ab commit 2f9d8c5

File tree

1 file changed

+18
-11
lines changed

1 file changed

+18
-11
lines changed

haystack/components/builders/chat_prompt_builder.py

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from haystack import component, default_from_dict, default_to_dict, logging
1313
from haystack.dataclasses.chat_message import ChatMessage, ChatRole, TextContent
1414
from haystack.lazy_imports import LazyImport
15-
from haystack.utils import Jinja2TimeExtension
15+
from haystack.utils import Jinja2TimeExtension, extract_declared_variables
1616
from haystack.utils.jinja2_chat_extension import ChatMessageExtension, templatize_part
1717

1818
logger = logging.getLogger(__name__)
@@ -171,21 +171,28 @@ def __init__(
171171

172172
extracted_variables = []
173173
if template and not variables:
174+
175+
def _extract_from_text(text: str, role: Optional[str] = None, is_filter_allowed: bool = False) -> list:
176+
if text is None:
177+
raise ValueError(NO_TEXT_ERROR_MESSAGE.format(role=role or "unknown", message=text))
178+
if is_filter_allowed and "templatize_part" in text:
179+
raise ValueError(FILTER_NOT_ALLOWED_ERROR_MESSAGE)
180+
181+
ast = self._env.parse(text)
182+
template_variables = meta.find_undeclared_variables(ast)
183+
assigned_variables = extract_declared_variables(text, env=self._env)
184+
return list(template_variables - assigned_variables)
185+
174186
if isinstance(template, list):
175187
for message in template:
176188
if message.is_from(ChatRole.USER) or message.is_from(ChatRole.SYSTEM):
177-
# infer variables from template
178-
if message.text is None:
179-
raise ValueError(NO_TEXT_ERROR_MESSAGE.format(role=message.role.value, message=message))
180-
if message.text and "templatize_part" in message.text:
181-
raise ValueError(FILTER_NOT_ALLOWED_ERROR_MESSAGE)
182-
ast = self._env.parse(message.text)
183-
template_variables = meta.find_undeclared_variables(ast)
184-
extracted_variables += list(template_variables)
189+
extracted_variables += _extract_from_text(
190+
message.text, role=message.role.value, is_filter_allowed=True
191+
)
185192
elif isinstance(template, str):
186-
ast = self._env.parse(template)
187-
extracted_variables = list(meta.find_undeclared_variables(ast))
193+
extracted_variables = _extract_from_text(template, is_filter_allowed=False)
188194

195+
extracted_variables = extracted_variables or []
189196
self.variables = variables or extracted_variables
190197
self.required_variables = required_variables or []
191198

0 commit comments

Comments
 (0)