Skip to content

Rag/rag transition 4 #25

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: rag/rag-transition-3
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 11 additions & 26 deletions lib/chatbot/chat.ex
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ defmodule Chatbot.Chat do
# `add_callback/2` expects a map with all possible handler functions.
# See:
# https://hexdocs.pm/langchain/0.3.0-rc.0/LangChain.Chains.ChainCallbacks.html#t:chain_callback_handler/0
@dialyzer {:nowarn_function, stream_assistant_message: 1}

@doc """
Creates a message.
Expand Down Expand Up @@ -42,8 +41,6 @@ defmodule Chatbot.Chat do
@spec request_assistant_message([Message.t()]) ::
{:ok, Message.t()} | {:error, String.t() | Ecto.Changeset.t()}
def request_assistant_message(messages) do
maybe_mock_llm()

messages = Enum.map(messages, &to_langchain_message/1)

@chain
Expand All @@ -64,15 +61,13 @@ defmodule Chatbot.Chat do

Once the full message has been processed, it is saved as an assistant message.
"""
@spec stream_assistant_message(pid()) :: Message.t()
def stream_assistant_message(receiver) do
messages = all_messages() |> Enum.map(&to_langchain_message/1)

{:ok, assistant_message} = create_message(%{role: :assistant, content: ""})
@spec stream_assistant_message(pid(), [Message.t()], Message.t()) :: Message.t()
def stream_assistant_message(receiver, messages, assistant_message) do
messages = Enum.map(messages, &to_langchain_message/1)

handler = %{
on_llm_new_delta: fn _model, %LangChain.MessageDelta{} = data ->
send(receiver, {:next_message_delta, assistant_message.id, data})
send(receiver, {:next_message_delta, assistant_message, data})
end,
on_message_processed: fn _chain, %LangChain.Message{} = data ->
completed_message = update_message!(assistant_message, %{content: data.content})
Expand All @@ -81,29 +76,19 @@ defmodule Chatbot.Chat do
end
}

Task.Supervisor.start_child(Chatbot.TaskSupervisor, fn ->
maybe_mock_llm(stream: true)

@chain
|> LLMChain.add_callback(handler)
|> LLMChain.add_llm_callback(handler)
|> LLMChain.add_messages(messages)
|> LLMChain.run()
end)

assistant_message
@chain
|> LLMChain.add_callback(handler)
|> LLMChain.add_llm_callback(handler)
|> LLMChain.add_messages(messages)
|> LLMChain.run()
end

defp to_langchain_message(%{role: :user, content: content}),
def to_langchain_message(%{role: :user, content: content}),
do: LangChain.Message.new_user!(content)

defp to_langchain_message(%{role: :assistant, content: content}),
def to_langchain_message(%{role: :assistant, content: content}),
do: LangChain.Message.new_assistant!(content)

defp maybe_mock_llm(opts \\ []) do
if Application.fetch_env!(:chatbot, :mock_llm_api), do: LLMMock.mock(opts)
end

@doc """
Lists all messages ordered by insertion date.
"""
Expand Down
3 changes: 2 additions & 1 deletion lib/chatbot/chat/message.ex
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ defmodule Chatbot.Chat.Message do
schema "messages" do
field :role, Ecto.Enum, values: @message_types
field :content, :string
field :sources, {:array, :string}

timestamps()
end
Expand All @@ -24,7 +25,7 @@ defmodule Chatbot.Chat.Message do
@spec changeset(t(), map()) :: Ecto.Changeset.t()
def changeset(message \\ %__MODULE__{}, attrs) do
message
|> cast(attrs, [:role])
|> cast(attrs, [:role, :sources])
|> cast(attrs, [:content], empty_values: [nil])
# we cannot require the content, as
# validate_required still considers "" as empty
Expand Down
13 changes: 4 additions & 9 deletions lib/chatbot/rag.ex
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ defmodule Chatbot.Rag do
Enum.map(chunks, &Map.put(ingestion, :chunk, &1.text))
end

def query(query) do
def build_generation(query) do
generation =
Generation.new(query)
|> Embedding.generate_embedding(@provider)
Expand All @@ -85,13 +85,12 @@ defmodule Chatbot.Rag do
Generation.get_retrieval_result(generation, :rrf_result)
|> Enum.map(& &1.source)

prompt = smollm_prompt(query, context)
prompt = prompt(query, context)

generation
|> Generation.put_context(context)
|> Generation.put_context_sources(context_sources)
|> Generation.put_prompt(prompt)
|> Generation.generate_response(@provider)
end

defp to_chunk(ingestion) do
Expand Down Expand Up @@ -124,19 +123,15 @@ defmodule Chatbot.Rag do
)}
end

defp smollm_prompt(query, context) do
defp prompt(query, context) do
"""
<|im_start|>system
You are a helpful assistant.<|im_end|>
<|im_start|>user
Context information is below.
---------------------
#{context}
---------------------
Given the context information and no prior knowledge, answer the query.
Query: #{query}
Answer: <|im_end|>
<|im_start|>assist
Answer:
"""
end
end
57 changes: 48 additions & 9 deletions lib/chatbot_web/live/chat_live.ex
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ defmodule ChatbotWeb.ChatLive do
use ChatbotWeb, :live_view
import ChatbotWeb.CoreComponents
import BitcrowdEcto.Random, only: [uuid: 0]
alias Chatbot.Chat
alias Chatbot.{Chat, Repo}

@impl Phoenix.LiveView
def mount(_params, _session, socket) do
Expand Down Expand Up @@ -36,6 +36,7 @@ defmodule ChatbotWeb.ChatLive do
id={dom_id}
role={message.role}
content={message.content}
sources={message.sources}
/>
</div>

Expand Down Expand Up @@ -76,18 +77,43 @@ defmodule ChatbotWeb.ChatLive do
~H"""
<.ui_card id={@id} class={@class}>
<%= @markdown %>

<details :if={@sources}>
<summary>Sources</summary>
<ol>
<li :for={source <- @sources}>
<%= source %>
</li>
</ol>
</details>
</.ui_card>
"""
end

@impl Phoenix.LiveView
def handle_event("send", %{"message" => %{"content" => content}}, socket) do
messages = Chat.all_messages()

pid = self()

with {:ok, user_message} <- Chat.create_message(%{role: :user, content: content}),
assistant_message <- Chat.stream_assistant_message(self()) do
{:ok, assistant_message} <- Chat.create_message(%{role: :assistant, content: ""}) do
{:noreply,
socket
|> assign(:form, build_form())
|> stream(:messages, [user_message, assistant_message])}
|> stream(:messages, [user_message, assistant_message])
|> start_async(:rag, fn ->
{:ok, augmented_user_message, augmentation} = augment_user_message(user_message)

assistant_message =
Chat.update_message!(assistant_message, %{sources: augmentation.context_sources})

Chat.stream_assistant_message(
pid,
messages ++ [augmented_user_message],
assistant_message
)
end)}
end
end

Expand All @@ -108,26 +134,31 @@ defmodule ChatbotWeb.ChatLive do
{:noreply, assign(socket, :currently_streamed_response, nil)}
end

def handle_info({:next_message_delta, id, %{status: :incomplete} = message_delta}, socket) do
def handle_info({:next_message_delta, message, %{status: :incomplete} = message_delta}, socket) do
currently_streamed_response = socket.assigns.currently_streamed_response

merged_message_deltas =
LangChain.MessageDelta.merge_delta(currently_streamed_response, message_delta)

{:noreply,
socket
|> stream_insert(:messages, %{
id: id,
role: :assistant,
content: merged_message_deltas.content
})
|> stream_insert(:messages, %{message | content: merged_message_deltas.content})
|> assign(:currently_streamed_response, merged_message_deltas)}
end

def handle_info({:message_processed, completed_message}, socket) do
{:noreply, stream_insert(socket, :messages, completed_message)}
end

def handle_info({_key, _event}, socket) do
{:noreply, socket}
end

# Handles the outcome of the `:rag` task started with `start_async/3` in the
# "send" event handler.
#
# Streamed deltas and the completed message are delivered to this LiveView
# through `handle_info/2`, so the task's return value is not needed here.
# A crashed task (`{:exit, reason}`) is logged instead of being dropped
# silently; the socket is returned unchanged in every case.
@impl Phoenix.LiveView
def handle_async(:rag, {:exit, reason}, socket) do
  # Logger's functions are macros, so the module must be required first.
  require Logger

  Logger.error("RAG task exited: #{inspect(reason)}")

  {:noreply, socket}
end

def handle_async(:rag, _result, socket) do
  {:noreply, socket}
end

defp build_form do
%{role: :user, content: ""}
|> Chat.Message.changeset()
Expand All @@ -136,4 +167,12 @@ defmodule ChatbotWeb.ChatLive do
# for a new message and clears the input
|> to_form(id: uuid())
end

# Swaps the content of a user message for the RAG prompt built from it.
#
# Returns `{:ok, augmented_message, generation}`, where `generation` is the
# value returned by `Chatbot.Rag.build_generation/1`. The assertive match
# raises a `MatchError` when the message does not have role `:user`.
defp augment_user_message(user_message) do
  %{content: question, role: :user} = user_message

  generation = Chatbot.Rag.build_generation(question)
  augmented_message = %{user_message | content: generation.prompt}

  {:ok, augmented_message, generation}
end
end
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# Migration that adds a `sources` column, an array of strings, to the
# `messages` table. The column is declared without a default and without a
# `null: false` option.
defmodule Chatbot.Repo.Migrations.AddSourcesToMessages do
  use Ecto.Migration

  def change do
    alter table(:messages) do
      add :sources, {:array, :string}
    end
  end
end
Loading