Skip to content

Commit 95c4318

Browse files
authored
Add ollama provider (#31)
1 parent 2fb10d0 commit 95c4318

File tree

3 files changed

+110
-0
lines changed

3 files changed

+110
-0
lines changed

lib/rag/ai/ollama.ex

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
defmodule Rag.Ai.Ollama do
2+
@moduledoc """
3+
Implementation of `Rag.Ai.Provider` using Ollama.
4+
"""
5+
6+
@behaviour Rag.Ai.Provider
7+
8+
@type t :: %__MODULE__{
9+
embeddings_url: String.t() | nil,
10+
embeddings_model: String.t() | nil,
11+
text_url: String.t() | nil,
12+
text_model: String.t() | nil
13+
}
14+
defstruct embeddings_url: "http://localhost:11434/api/embed",
15+
embeddings_model: nil,
16+
text_url: "http://localhost:11434/api/chat",
17+
text_model: nil
18+
19+
@impl Rag.Ai.Provider
20+
def new(attrs) do
21+
struct!(__MODULE__, attrs)
22+
end
23+
24+
@impl Rag.Ai.Provider
25+
def generate_embeddings(%__MODULE__{} = provider, texts, _opts \\ []) do
26+
req_params =
27+
[
28+
json: %{"model" => provider.embeddings_model, "input" => texts}
29+
]
30+
31+
with {:ok, %Req.Response{status: 200} = response} <-
32+
Req.post(provider.embeddings_url, req_params),
33+
{:ok, embeddings} <- get_embeddings(response) do
34+
{:ok, embeddings}
35+
else
36+
{:ok, %Req.Response{status: status}} ->
37+
{:error, "HTTP request failed with status code #{status}"}
38+
39+
{:error, reason} ->
40+
{:error, reason}
41+
end
42+
end
43+
44+
defp get_embeddings(response) do
45+
path = ["embeddings"]
46+
47+
case get_in(response.body, path) do
48+
nil ->
49+
{:error,
50+
"failed to access embeddings from path embeddings in response #{inspect(response.body)}"}
51+
52+
embeddings ->
53+
{:ok, embeddings}
54+
end
55+
end
56+
57+
@impl Rag.Ai.Provider
58+
def generate_text(%__MODULE__{} = provider, prompt, _opts \\ []) do
59+
req_params =
60+
[
61+
json: %{
62+
"model" => provider.text_model,
63+
"messages" => [%{role: :user, content: prompt}],
64+
"stream" => false
65+
}
66+
]
67+
68+
with {:ok, %Req.Response{status: 200} = response} <- Req.post(provider.text_url, req_params),
69+
{:ok, text} <- get_text(response) do
70+
{:ok, text}
71+
else
72+
{:ok, %Req.Response{status: status}} ->
73+
{:error, "HTTP request failed with status code #{status}"}
74+
75+
{:error, reason} ->
76+
{:error, reason}
77+
end
78+
end
79+
80+
defp get_text(response) do
81+
path = ["message", "content"]
82+
83+
case get_in(response.body, path) do
84+
nil ->
85+
{:error, "failed to access text from path response in response #{inspect(response.body)}"}
86+
87+
text ->
88+
{:ok, text}
89+
end
90+
end
91+
end

test/rag/embedding/http_test.exs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,5 +208,13 @@ defmodule Rag.Embedding.HttpTest do
208208
[%{text: "hello", embedding: _embedding}] =
209209
Embedding.generate_embeddings_batch([%{text: "hello"}], provider, [])
210210
end
211+
212+
@tag :integration_test
213+
test "ollama embeddings" do
214+
provider = Ai.Ollama.new(%{embeddings_model: "unclemusclez/jina-embeddings-v2-base-code"})
215+
216+
assert [%{text: "hello", embedding: _embedding}] =
217+
Embedding.generate_embeddings_batch([%{text: "hello"}], provider, [])
218+
end
211219
end
212220
end

test/rag/generation/http_test.exs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,5 +40,16 @@ defmodule Rag.Generation.HttpTest do
4040
%Generation{query: "test?", response: _response} =
4141
Generation.generate_response(%Generation{query: "test?", prompt: "prompt"}, provider)
4242
end
43+
44+
@tag :integration_test
45+
test "ollama generation" do
46+
provider = Ai.Ollama.new(%{text_model: "llama3.2:latest"})
47+
48+
assert %Generation{query: "test?", response: _response} =
49+
Generation.generate_response(
50+
%Generation{query: "test?", prompt: "prompt"},
51+
provider
52+
)
53+
end
4354
end
4455
end

0 commit comments

Comments
 (0)