Skip to content

Commit 4c745ad

Browse files
committed
added chat history capability
1 parent 4e76d29 commit 4c745ad

File tree

8 files changed

+597
-669
lines changed

8 files changed

+597
-669
lines changed

Cargo.lock

Lines changed: 437 additions & 567 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
[package]
22
name = "dkn-oracle"
33
description = "Dria Knowledge Network: Oracle Node"
4-
version = "0.1.7"
4+
version = "0.1.8"
55
edition = "2021"
66
license = "Apache-2.0"
77
readme = "README.md"

src/compute/handlers/generation.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ pub async fn handle_generation(
5050
// execute task
5151
log::debug!("Executing the workflow");
5252
let protocol_string = bytes32_to_string(&protocol)?;
53-
let input = Request::try_parse_bytes(&request.input).await?;
53+
let mut input = Request::try_parse_bytes(&request.input).await?;
5454
let output = input.execute(model, Some(node)).await?;
5555
log::debug!("Output: {}", output);
5656

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
{
2+
"name": "LLM generation",
3+
"description": "Directly generate text with input",
4+
"config": {
5+
"max_steps": 10,
6+
"max_time": 50,
7+
"tools": [""]
8+
},
9+
"external_memory": {
10+
"context": [""],
11+
"question": [""],
12+
"answer": [""]
13+
},
14+
"tasks": [
15+
{
16+
"id": "A",
17+
"name": "Generate with history",
18+
"description": "Expects an array of messages for generation",
19+
"messages": [],
20+
"inputs": [],
21+
"operator": "generation",
22+
"outputs": [
23+
{
24+
"type": "write",
25+
"key": "result",
26+
"value": "__result"
27+
}
28+
]
29+
},
30+
{
31+
"id": "__end",
32+
"name": "end",
33+
"description": "End of the task",
34+
"messages": [{ "role": "user", "content": "End of the task" }],
35+
"inputs": [],
36+
"operator": "end",
37+
"outputs": []
38+
}
39+
],
40+
"steps": [
41+
{
42+
"source": "A",
43+
"target": "__end"
44+
}
45+
],
46+
"return_value": {
47+
"input": {
48+
"type": "read",
49+
"key": "result"
50+
}
51+
}
52+
}

src/compute/workflows/presets/generation.json

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,13 @@
1515
{
1616
"id": "A",
1717
"name": "Generate",
18-
"description": "",
19-
"prompt": "{text}",
18+
"description": "Executes a simple generation request",
19+
"messages": [
20+
{
21+
"role": "user",
22+
"content": "{{text}}"
23+
}
24+
],
2025
"inputs": [
2126
{
2227
"name": "text",
@@ -40,7 +45,7 @@
4045
"id": "__end",
4146
"name": "end",
4247
"description": "End of the task",
43-
"prompt": "End of the task",
48+
"messages": [{ "role": "user", "content": "End of the task" }],
4449
"inputs": [],
4550
"operator": "end",
4651
"outputs": []

src/compute/workflows/presets/mod.rs

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,11 @@ use dkn_workflows::Workflow;
22
use lazy_static::lazy_static;
33

44
lazy_static! {
5-
pub static ref GENERATION_WORKFLOW: Workflow = {
5+
pub static ref GENERATION_WORKFLOW: Workflow =
66
serde_json::from_str(include_str!("generation.json"))
7-
.expect("could not parse generation workflow")
8-
};
7+
.expect("could not parse generation workflow");
8+
}
9+
10+
pub fn get_search_workflow() -> Workflow {
11+
serde_json::from_str(include_str!("chat.json")).expect("could not parse chat workflow")
912
}

src/compute/workflows/requests/chat.rs

Lines changed: 0 additions & 37 deletions
This file was deleted.

src/compute/workflows/requests/mod.rs

Lines changed: 92 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,23 @@
11
use alloy::primitives::{Bytes, U256};
2-
use dkn_workflows::{Entry, Executor, Model, ProgramMemory, Workflow};
2+
use dkn_workflows::{Entry, Executor, MessageInput, Model, ProgramMemory, Workflow};
33
use eyre::{eyre, Context, Result};
44

5-
mod chat;
6-
use chat::*;
7-
8-
use super::{postprocess::*, presets::GENERATION_WORKFLOW};
5+
use super::{postprocess::*, presets::*};
96
use crate::{
107
bytes_to_string,
118
data::{Arweave, OracleExternalData},
129
DriaOracle,
1310
};
1411

12+
/// A request with chat history.
13+
#[derive(Debug, serde::Serialize, serde::Deserialize)]
14+
pub struct ChatHistoryRequest {
15+
/// Task Id of which the output will act like history.
16+
pub history_id: usize,
17+
/// Message content.
18+
pub content: String,
19+
}
20+
1521
/// An oracle request.
1622
#[derive(Debug)]
1723
pub enum Request {
@@ -38,7 +44,7 @@ impl Request {
3844

3945
/// Executes a request using the given model, and optionally a node.
4046
/// Returns the raw string output.
41-
pub async fn execute(&self, model: Model, node: Option<&DriaOracle>) -> Result<String> {
47+
pub async fn execute(&mut self, model: Model, node: Option<&DriaOracle>) -> Result<String> {
4248
log::debug!("Executing workflow with: {}", model);
4349
let mut memory = ProgramMemory::new();
4450
let executor = Executor::new(model);
@@ -57,44 +63,53 @@ impl Request {
5763
.wrap_err("could not execute workflow for string input")
5864
}
5965

60-
Self::ChatHistory(chat_history) => {
61-
if let Some(node) = node {
66+
Self::ChatHistory(chat_request) => {
67+
let mut history = if chat_request.history_id == 0 {
6268
// if task id is zero, there is no prior history
63-
let mut history = if chat_history.history_id == 0 {
64-
Vec::new()
65-
} else {
66-
// get history from blockchain if requested
67-
let history_task = node
68-
.get_task_best_response(U256::from(chat_history.history_id))
69-
.await
70-
.wrap_err("could not get chat history task from contract")?;
71-
72-
// parse it as chat history output
73-
let history_str = Self::parse_downloadable(&history_task.output).await?;
74-
75-
serde_json::from_str::<Vec<ChatHistoryResponse>>(&history_str)?
76-
};
77-
78-
// execute the workflow
79-
// TODO: add chat history to memory
80-
let entry = Entry::String(chat_history.content.clone());
81-
let output = executor
82-
.execute(Some(&entry), &GENERATION_WORKFLOW, &mut memory)
69+
Vec::new()
70+
} else if let Some(node) = node {
71+
// if task id is non-zero, we need the node to get the history
72+
let history_task = node
73+
.get_task_best_response(U256::from(chat_request.history_id))
8374
.await
84-
.wrap_err("could not execute chat workflow")?;
75+
.wrap_err("could not get chat history task from contract")?;
8576

86-
// append user input & workflow output to chat history
87-
history.push(ChatHistoryResponse::user(chat_history.content.clone()));
88-
history.push(ChatHistoryResponse::assistant(output));
77+
// parse it as chat history output
78+
let history_str = Self::parse_downloadable(&history_task.output).await?;
8979

90-
// return the stringified output
91-
let out = serde_json::to_string(&history)
92-
.wrap_err("could not serialize chat history")?;
93-
94-
Ok(out)
80+
serde_json::from_str::<Vec<MessageInput>>(&history_str)?
9581
} else {
96-
Err(eyre!("node is required for chat history"))
97-
}
82+
return Err(eyre!("node is required for chat history"));
83+
};
84+
85+
// append user input to chat history
86+
history.push(MessageInput {
87+
role: "user".to_string(),
88+
content: chat_request.content.clone(),
89+
});
90+
91+
// prepare the workflow with chat history
92+
let mut workflow = get_search_workflow();
93+
let task = workflow.get_tasks_by_id_mut("A").unwrap();
94+
task.messages = history.clone();
95+
96+
// call workflow
97+
let output = executor
98+
.execute(None, &workflow, &mut memory)
99+
.await
100+
.wrap_err("could not execute chat workflow")?;
101+
102+
// append user input to chat history
103+
history.push(MessageInput {
104+
role: "assistant".to_string(),
105+
content: output,
106+
});
107+
108+
// return the stringified output
109+
let out =
110+
serde_json::to_string(&history).wrap_err("could not serialize chat history")?;
111+
112+
Ok(out)
98113
}
99114
}
100115
}
@@ -135,27 +150,29 @@ mod tests {
135150
use super::*;
136151

137152
// only implemented for testing purposes
138-
// because Workflow and ChatHistory do not implement PartialEq
153+
// because `Workflow` does not implement `PartialEq`
139154
impl PartialEq for Request {
140155
fn eq(&self, other: &Self) -> bool {
141156
match (self, other) {
142-
(Self::ChatHistory(_), Self::ChatHistory(_)) => true,
143-
(Self::Workflow(_), Self::Workflow(_)) => true,
157+
(Self::ChatHistory(a), Self::ChatHistory(b)) => {
158+
a.content == b.content && a.history_id == b.history_id
159+
}
160+
(Self::Workflow(_), Self::Workflow(_)) => true, // not implemented
144161
(Self::String(a), Self::String(b)) => a == b,
145162
_ => false,
146163
}
147164
}
148165
}
149166

150167
#[tokio::test]
151-
async fn test_parse_input_string() {
152-
let input_str = "foobar";
153-
let entry = Request::try_parse_bytes(&input_str.as_bytes().into()).await;
154-
assert_eq!(entry.unwrap(), Request::String(input_str.into()));
168+
async fn test_parse_request_string() {
169+
let request_str = "foobar";
170+
let entry = Request::try_parse_bytes(&request_str.as_bytes().into()).await;
171+
assert_eq!(entry.unwrap(), Request::String(request_str.into()));
155172
}
156173

157174
#[tokio::test]
158-
async fn test_parse_input_arweave() {
175+
async fn test_parse_request_arweave() {
159176
// contains the string: "\"Hello, Arweave!\""
160177
// hex for: Zg6CZYfxXCWYnCuKEpnZCYfy7ghit1_v4-BCe53iWuA
161178
let arweave_key = "660e826587f15c25989c2b8a1299d90987f2ee0862b75fefe3e0427b9de25ae0";
@@ -166,7 +183,7 @@ mod tests {
166183
}
167184

168185
#[tokio::test]
169-
async fn test_parse_input_workflow() {
186+
async fn test_parse_request_workflow() {
170187
let workflow_str = include_str!("../presets/generation.json");
171188
let expected_workflow = serde_json::from_str::<Workflow>(&workflow_str).unwrap();
172189

@@ -175,22 +192,22 @@ mod tests {
175192
}
176193

177194
#[tokio::test]
178-
async fn test_parse_input_chat() {
179-
let input = ChatHistoryRequest {
195+
async fn test_parse_request_chat() {
196+
let request = ChatHistoryRequest {
180197
history_id: 0,
181198
content: "foobar".to_string(),
182199
};
183-
let input_bytes = serde_json::to_vec(&input).unwrap();
184-
let entry = Request::try_parse_bytes(&input_bytes.into()).await;
185-
assert_eq!(entry.unwrap(), Request::ChatHistory(input));
200+
let request_bytes = serde_json::to_vec(&request).unwrap();
201+
let entry = Request::try_parse_bytes(&request_bytes.into()).await;
202+
assert_eq!(entry.unwrap(), Request::ChatHistory(request));
186203
}
187204

188205
#[tokio::test]
189206
#[ignore = "run this manually"]
190207
async fn test_ollama_generation() {
191208
dotenvy::dotenv().unwrap();
192-
let input = Request::String("What is the result of 2 + 2?".to_string());
193-
let output = input.execute(Model::Llama3_1_8B, None).await.unwrap();
209+
let mut request = Request::String("What is the result of 2 + 2?".to_string());
210+
let output = request.execute(Model::Llama3_1_8B, None).await.unwrap();
194211

195212
println!("Output:\n{}", output);
196213
assert!(output.contains('4'));
@@ -200,8 +217,26 @@ mod tests {
200217
#[ignore = "run this manually"]
201218
async fn test_openai_generation() {
202219
dotenvy::dotenv().unwrap();
203-
let input = Request::String("What is the result of 2 + 2?".to_string());
204-
let output = input.execute(Model::GPT4Turbo, None).await.unwrap();
220+
let mut request = Request::String("What is the result of 2 + 2?".to_string());
221+
let output = request.execute(Model::GPT4Turbo, None).await.unwrap();
222+
223+
println!("Output:\n{}", output);
224+
assert!(output.contains('4'));
225+
}
226+
227+
#[tokio::test]
228+
#[ignore = "run this manually"]
229+
async fn test_openai_chat() {
230+
dotenvy::dotenv().unwrap();
231+
let request = ChatHistoryRequest {
232+
history_id: 0,
233+
content: "What is 2+2?".to_string(),
234+
};
235+
let request_bytes = serde_json::to_vec(&request).unwrap();
236+
let mut request = Request::try_parse_bytes(&request_bytes.into())
237+
.await
238+
.unwrap();
239+
let output = request.execute(Model::GPT4Turbo, None).await.unwrap();
205240

206241
println!("Output:\n{}", output);
207242
assert!(output.contains('4'));

0 commit comments

Comments
 (0)