44
55from langchain_community .document_loaders import PyPDFLoader , TextLoader
66from langchain .text_splitter import RecursiveCharacterTextSplitter
7- from langchain .chains import ConversationalRetrievalChain
8- from langchain .memory import ChatMessageHistory , ConversationBufferMemory
7+ # from langchain.chains import ConversationalRetrievalChain
8+ # from langchain.memory import ConversationBufferMemory
9+ from langchain .memory import ChatMessageHistory
910from langchain_anthropic import ChatAnthropic
1011from langchain_community .vectorstores import Chroma
1112from langchain_community .embeddings import HuggingFaceEmbeddings
1617# For Approach 1
1718from langchain .chains .combine_documents import create_stuff_documents_chain
1819from langchain_core .prompts import ChatPromptTemplate , MessagesPlaceholder
19- from typing import Dict
20- from langchain_core .runnables import RunnablePassthrough
21- from langchain_core .messages import HumanMessage
20+ # from typing import Dict
21+ # from langchain_core.runnables import RunnablePassthrough
22+ # from langchain_core.messages import HumanMessage
2223
2324# For Approach 3
24- from langchain .chains .question_answering import load_qa_chain
25- from langchain .prompts import PromptTemplate
25+ # from langchain.chains.question_answering import load_qa_chain
26+ # from langchain.prompts import PromptTemplate
2627
28+ # For Approach 4
29+ from langchain .chains import create_history_aware_retriever , create_retrieval_chain
30+ from langchain_core .chat_history import BaseChatMessageHistory
31+ from langchain_core .runnables .history import RunnableWithMessageHistory
2732
2833ANTHROPIC_API_KEY = os .getenv ("ANTHROPIC_API_KEY" )
2934llm = ChatAnthropic (temperature = 0 , model_name = "claude-3-opus-20240229" )
@@ -72,18 +77,10 @@ def store_embeddings(chunks):
7277 print (f"Size of vectordb: { vectordb ._collection .count ()} " )
7378 return vectordb
7479
75- # Retrieve data from the query
76- # def simple_retrieval(vectordb, message: cl.Message):
77- # query = message.content
7880
79- # results = vectordb.similarity_search(query, k=5)
80- # retrieved_documents = [result.page_content for result in results]
81-
82- # return retrieved_documents
83-
84- @cl .step
85- def ChainOfThought ():
86- return "Not working yet"
81+ # @cl.step
82+ # def ChainOfThought():
83+ # return "Not working yet"
8784
8885
8986@cl .on_chat_start
@@ -156,23 +153,23 @@ async def start():
156153
157154 # Approach 2 (high level approach)
158155
159- message_history = ChatMessageHistory ()
156+ # message_history = ChatMessageHistory()
160157
161- memory = ConversationBufferMemory (
162- memory_key = "chat_history" ,
163- input_key = "question" ,
164- output_key = "answer" ,
165- chat_memory = message_history ,
166- return_messages = True ,
167- )
158+ # memory = ConversationBufferMemory(
159+ # memory_key="chat_history",
160+ # input_key="question",
161+ # output_key="answer",
162+ # chat_memory=message_history,
163+ # return_messages=True,
164+ # )
168165
169- chain = ConversationalRetrievalChain .from_llm (
170- llm ,
171- chain_type = "stuff" ,
172- retriever = vectordb .as_retriever (search_type = "similarity" , search_kwargs = {"k" :5 }),
173- memory = memory ,
174- return_source_documents = True ,
175- )
166+ # chain = ConversationalRetrievalChain.from_llm(
167+ # llm,
168+ # chain_type="stuff",
169+ # retriever=vectordb.as_retriever(search_type="similarity", search_kwargs={"k":5}),
170+ # memory=memory,
171+ # return_source_documents=True,
172+ # )
176173
177174 # debugging
178175 # if chain.memory is not None:
@@ -206,8 +203,59 @@ async def start():
206203 # cl.user_session.set("vectordb", vectordb)
207204
208205 # Approach 4 (Approach 1 + memory)
209-
210206
207+
208+ retriever = vectordb .as_retriever (search_type = "similarity" , search_kwargs = {"k" :5 })
209+
210+ ### Contextualize question ###
211+ contextualize_q_system_prompt = """Given a chat history and the latest user question \
212+ which might reference context in the chat history, formulate a standalone question \
213+ which can be understood without the chat history. Do NOT answer the question, \
214+ just reformulate it if needed and otherwise return it as is."""
215+ contextualize_q_prompt = ChatPromptTemplate .from_messages (
216+ [
217+ ("system" , contextualize_q_system_prompt ),
218+ MessagesPlaceholder ("chat_history" ),
219+ ("human" , "{input}" ),
220+ ]
221+ )
222+ history_aware_retriever = create_history_aware_retriever (
223+ llm , retriever , contextualize_q_prompt
224+ )
225+ ### Answer question ###
226+ qa_system_prompt = """You are an assistant for question-answering tasks. \
227+ Use the following pieces of retrieved context to answer the question. \
228+ If the context doesn't contain any relevant information to the question, don't make something up and just say that you don't know. \
229+
230+ {context}"""
231+ qa_prompt = ChatPromptTemplate .from_messages (
232+ [
233+ ("system" , qa_system_prompt ),
234+ MessagesPlaceholder ("chat_history" ),
235+ ("human" , "{input}" ),
236+ ]
237+ )
238+ question_answer_chain = create_stuff_documents_chain (llm , qa_prompt )
239+
240+ rag_chain = create_retrieval_chain (history_aware_retriever , question_answer_chain )
241+
242+
243+ ### Statefully manage chat history ###
244+ store = {}
245+
def get_session_history(session_id: str) -> BaseChatMessageHistory:
    """Return the chat history for *session_id*, creating one on first use.

    NOTE(review): backed by the enclosing in-memory `store` dict, so
    histories live only for the process lifetime — confirm this is the
    intended persistence model before deploying.
    """
    try:
        return store[session_id]
    except KeyError:
        history = ChatMessageHistory()
        store[session_id] = history
        return history
250+
251+ chain = RunnableWithMessageHistory (
252+ rag_chain ,
253+ get_session_history ,
254+ input_messages_key = "input" ,
255+ history_messages_key = "chat_history" ,
256+ output_messages_key = "answer" ,
257+ )
258+
211259 #COMMON TO ALL APPROACHES
212260 msg .content = f"`{ file .name } ` processed. You can now ask questions!"
213261 await msg .update ()
@@ -218,6 +266,7 @@ async def start():
218266@cl .on_message
219267async def main (message : cl .Message ):
220268 chain = cl .user_session .get ("chain" )
269+
221270 # Approach 1
222271 # response = chain.invoke({
223272 # "messages": [
@@ -227,33 +276,33 @@ async def main(message: cl.Message):
227276 # await cl.Message(response["answer"]).send()
228277
229278 # Approach 2
230- response = await chain .acall (message .content , callbacks = [cl .AsyncLangchainCallbackHandler ()])
231- print (response ) # debugging
232- answer = response ["answer" ]
233- source_documents = response ["source_documents" ]
234- text_elements = []
235- unique_pages = set ()
279+ # response = await chain.acall(message.content, callbacks=[cl.AsyncLangchainCallbackHandler()])
280+ # print(response) # debugging
281+ # answer = response["answer"]
282+ # source_documents = response["source_documents"]
283+ # text_elements = []
284+ # unique_pages = set()
236285
237- if source_documents :
238- for source_idx , source_doc in enumerate (source_documents ):
239- source_name = f"source_{ source_idx } "
240- page_number = source_doc .metadata ['page' ]
241- page = f"Page { page_number } "
242- text_element_content = source_doc .page_content
243- # text_elements.append(cl.Text(content=text_element_content, name=source_name))
244- if page not in unique_pages :
245- unique_pages .add (page )
246- text_elements .append (cl .Text (content = text_element_content , name = page ))
247- # text_elements.append(cl.Text(content=text_element_content, name=page))
248- source_names = [text_el .name for text_el in text_elements ]
249- if source_names :
250- answer += f"\n \n Sources: { ', ' .join (source_names )} "
251- else :
252- answer += "\n \n No sources found"
286+ # if source_documents:
287+ # for source_idx, source_doc in enumerate(source_documents):
288+ # source_name = f"source_{source_idx}"
289+ # page_number = source_doc.metadata['page']
290+ # page = f"Page {page_number}"
291+ # text_element_content = source_doc.page_content
292+ # # text_elements.append(cl.Text(content=text_element_content, name=source_name))
293+ # if page not in unique_pages:
294+ # unique_pages.add(page)
295+ # text_elements.append(cl.Text(content=text_element_content, name=page))
296+ # # text_elements.append(cl.Text(content=text_element_content, name=page))
297+ # source_names = [text_el.name for text_el in text_elements]
298+ # if source_names:
299+ # answer += f"\n\nSources: {', '.join(source_names)}"
300+ # else:
301+ # answer += "\n\nNo sources found"
253302
254303 # ChainOfThought()
255304
256- await cl .Message (content = answer , elements = text_elements ).send ()
305+ # await cl.Message(content=answer, elements=text_elements).send()
257306
258307
259308
@@ -295,3 +344,35 @@ async def main(message: cl.Message):
295344 # await cl.Message(content=answer).send()
296345
297346 # Approach 4 (Approach 1 + memory)
347+ response = await chain .ainvoke (
348+ {"input" : message .content },
349+ config = {"configurable" : {"session_id" : "abc123" },
350+ "callbacks" :[cl .AsyncLangchainCallbackHandler ()]},
351+ )
352+
353+ # print(response) #debugging
354+
355+ answer = response ["answer" ]
356+
357+ source_documents = response ["context" ]
358+ text_elements = []
359+ unique_pages = set ()
360+
361+ if source_documents :
362+ for source_idx , source_doc in enumerate (source_documents ):
363+ source_name = f"source_{ source_idx } "
364+ page_number = source_doc .metadata ['page' ]
365+ page = f"Page { page_number } "
366+ text_element_content = source_doc .page_content
367+ # text_elements.append(cl.Text(content=text_element_content, name=source_name))
368+ if page not in unique_pages :
369+ unique_pages .add (page )
370+ text_elements .append (cl .Text (content = text_element_content , name = page ))
371+ # text_elements.append(cl.Text(content=text_element_content, name=page))
372+ source_names = [text_el .name for text_el in text_elements ]
373+ if source_names :
374+ answer += f"\n \n Sources: { ', ' .join (source_names )} "
375+ else :
376+ answer += "\n \n No sources found"
377+
378+ await cl .Message (content = answer , elements = text_elements ).send ()
0 commit comments