11use alloy:: primitives:: { Bytes , U256 } ;
2- use dkn_workflows:: { Entry , Executor , Model , ProgramMemory , Workflow } ;
2+ use dkn_workflows:: { Entry , Executor , MessageInput , Model , ProgramMemory , Workflow } ;
33use eyre:: { eyre, Context , Result } ;
44
5- mod chat;
6- use chat:: * ;
7-
8- use super :: { postprocess:: * , presets:: GENERATION_WORKFLOW } ;
5+ use super :: { postprocess:: * , presets:: * } ;
96use crate :: {
107 bytes_to_string,
118 data:: { Arweave , OracleExternalData } ,
129 DriaOracle ,
1310} ;
1411
12+ /// A request with chat history.
13+ #[ derive( Debug , serde:: Serialize , serde:: Deserialize ) ]
14+ pub struct ChatHistoryRequest {
15+ /// Task Id of which the output will act like history.
16+ pub history_id : usize ,
17+ /// Message content.
18+ pub content : String ,
19+ }
20+
1521/// An oracle request.
1622#[ derive( Debug ) ]
1723pub enum Request {
@@ -38,7 +44,7 @@ impl Request {
3844
3945 /// Executes a request using the given model, and optionally a node.
4046 /// Returns the raw string output.
41- pub async fn execute ( & self , model : Model , node : Option < & DriaOracle > ) -> Result < String > {
47+ pub async fn execute ( & mut self , model : Model , node : Option < & DriaOracle > ) -> Result < String > {
4248 log:: debug!( "Executing workflow with: {}" , model) ;
4349 let mut memory = ProgramMemory :: new ( ) ;
4450 let executor = Executor :: new ( model) ;
@@ -57,44 +63,53 @@ impl Request {
5763 . wrap_err ( "could not execute worfklow for string input" )
5864 }
5965
60- Self :: ChatHistory ( chat_history ) => {
61- if let Some ( node ) = node {
66+ Self :: ChatHistory ( chat_request ) => {
67+ let mut history = if chat_request . history_id == 0 {
6268 // if task id is zero, there is no prior history
63- let mut history = if chat_history. history_id == 0 {
64- Vec :: new ( )
65- } else {
66- // get history from blockchain if requested
67- let history_task = node
68- . get_task_best_response ( U256 :: from ( chat_history. history_id ) )
69- . await
70- . wrap_err ( "could not get chat history task from contract" ) ?;
71-
72- // parse it as chat history output
73- let history_str = Self :: parse_downloadable ( & history_task. output ) . await ?;
74-
75- serde_json:: from_str :: < Vec < ChatHistoryResponse > > ( & history_str) ?
76- } ;
77-
78- // execute the workflow
79- // TODO: add chat history to memory
80- let entry = Entry :: String ( chat_history. content . clone ( ) ) ;
81- let output = executor
82- . execute ( Some ( & entry) , & GENERATION_WORKFLOW , & mut memory)
69+ Vec :: new ( )
70+ } else if let Some ( node) = node {
71+ // if task id is non-zero, we need the node to get the history
72+ let history_task = node
73+ . get_task_best_response ( U256 :: from ( chat_request. history_id ) )
8374 . await
84- . wrap_err ( "could not execute chat worfklow " ) ?;
75+ . wrap_err ( "could not get chat history task from contract " ) ?;
8576
86- // append user input & workflow output to chat history
87- history. push ( ChatHistoryResponse :: user ( chat_history. content . clone ( ) ) ) ;
88- history. push ( ChatHistoryResponse :: assistant ( output) ) ;
77+ // parse it as chat history output
78+ let history_str = Self :: parse_downloadable ( & history_task. output ) . await ?;
8979
90- // return the stringified output
91- let out = serde_json:: to_string ( & history)
92- . wrap_err ( "could not serialize chat history" ) ?;
93-
94- Ok ( out)
80+ serde_json:: from_str :: < Vec < MessageInput > > ( & history_str) ?
9581 } else {
96- Err ( eyre ! ( "node is required for chat history" ) )
97- }
82+ return Err ( eyre ! ( "node is required for chat history" ) ) ;
83+ } ;
84+
85+ // append user input to chat history
86+ history. push ( MessageInput {
87+ role : "user" . to_string ( ) ,
88+ content : chat_request. content . clone ( ) ,
89+ } ) ;
90+
91+ // prepare the workflow with chat history
92+ let mut workflow = get_search_workflow ( ) ;
93+ let task = workflow. get_tasks_by_id_mut ( "A" ) . unwrap ( ) ;
94+ task. messages = history. clone ( ) ;
95+
96+ // call workflow
97+ let output = executor
98+ . execute ( None , & workflow, & mut memory)
99+ . await
100+ . wrap_err ( "could not execute chat worfklow" ) ?;
101+
102+ // append user input to chat history
103+ history. push ( MessageInput {
104+ role : "assistant" . to_string ( ) ,
105+ content : output,
106+ } ) ;
107+
108+ // return the stringified output
109+ let out =
110+ serde_json:: to_string ( & history) . wrap_err ( "could not serialize chat history" ) ?;
111+
112+ Ok ( out)
98113 }
99114 }
100115 }
@@ -135,27 +150,29 @@ mod tests {
135150 use super :: * ;
136151
137152 // only implemented for testing purposes
138- // because Workflow and ChatHistory do not implement PartialEq
153+ // because ` Workflow` does not implement ` PartialEq`
139154 impl PartialEq for Request {
140155 fn eq ( & self , other : & Self ) -> bool {
141156 match ( self , other) {
142- ( Self :: ChatHistory ( _) , Self :: ChatHistory ( _) ) => true ,
143- ( Self :: Workflow ( _) , Self :: Workflow ( _) ) => true ,
157+ ( Self :: ChatHistory ( a) , Self :: ChatHistory ( b) ) => {
158+ a. content == b. content && a. history_id == b. history_id
159+ }
160+ ( Self :: Workflow ( _) , Self :: Workflow ( _) ) => true , // not implemented
144161 ( Self :: String ( a) , Self :: String ( b) ) => a == b,
145162 _ => false ,
146163 }
147164 }
148165 }
149166
150167 #[ tokio:: test]
151- async fn test_parse_input_string ( ) {
152- let input_str = "foobar" ;
153- let entry = Request :: try_parse_bytes ( & input_str . as_bytes ( ) . into ( ) ) . await ;
154- assert_eq ! ( entry. unwrap( ) , Request :: String ( input_str . into( ) ) ) ;
168+ async fn test_parse_request_string ( ) {
169+ let request_str = "foobar" ;
170+ let entry = Request :: try_parse_bytes ( & request_str . as_bytes ( ) . into ( ) ) . await ;
171+ assert_eq ! ( entry. unwrap( ) , Request :: String ( request_str . into( ) ) ) ;
155172 }
156173
157174 #[ tokio:: test]
158- async fn test_parse_input_arweave ( ) {
175+ async fn test_parse_request_arweave ( ) {
159176 // contains the string: "\"Hello, Arweave!\""
160177 // hex for: Zg6CZYfxXCWYnCuKEpnZCYfy7ghit1_v4-BCe53iWuA
161178 let arweave_key = "660e826587f15c25989c2b8a1299d90987f2ee0862b75fefe3e0427b9de25ae0" ;
@@ -166,7 +183,7 @@ mod tests {
166183 }
167184
168185 #[ tokio:: test]
169- async fn test_parse_input_workflow ( ) {
186+ async fn test_parse_request_workflow ( ) {
170187 let workflow_str = include_str ! ( "../presets/generation.json" ) ;
171188 let expected_workflow = serde_json:: from_str :: < Workflow > ( & workflow_str) . unwrap ( ) ;
172189
@@ -175,22 +192,22 @@ mod tests {
175192 }
176193
177194 #[ tokio:: test]
178- async fn test_parse_input_chat ( ) {
179- let input = ChatHistoryRequest {
195+ async fn test_parse_request_chat ( ) {
196+ let request = ChatHistoryRequest {
180197 history_id : 0 ,
181198 content : "foobar" . to_string ( ) ,
182199 } ;
183- let input_bytes = serde_json:: to_vec ( & input ) . unwrap ( ) ;
184- let entry = Request :: try_parse_bytes ( & input_bytes . into ( ) ) . await ;
185- assert_eq ! ( entry. unwrap( ) , Request :: ChatHistory ( input ) ) ;
200+ let request_bytes = serde_json:: to_vec ( & request ) . unwrap ( ) ;
201+ let entry = Request :: try_parse_bytes ( & request_bytes . into ( ) ) . await ;
202+ assert_eq ! ( entry. unwrap( ) , Request :: ChatHistory ( request ) ) ;
186203 }
187204
188205 #[ tokio:: test]
189206 #[ ignore = "run this manually" ]
190207 async fn test_ollama_generation ( ) {
191208 dotenvy:: dotenv ( ) . unwrap ( ) ;
192- let input = Request :: String ( "What is the result of 2 + 2?" . to_string ( ) ) ;
193- let output = input . execute ( Model :: Llama3_1_8B , None ) . await . unwrap ( ) ;
209+ let mut request = Request :: String ( "What is the result of 2 + 2?" . to_string ( ) ) ;
210+ let output = request . execute ( Model :: Llama3_1_8B , None ) . await . unwrap ( ) ;
194211
195212 println ! ( "Output:\n {}" , output) ;
196213 assert ! ( output. contains( '4' ) ) ;
@@ -200,8 +217,26 @@ mod tests {
200217 #[ ignore = "run this manually" ]
201218 async fn test_openai_generation ( ) {
202219 dotenvy:: dotenv ( ) . unwrap ( ) ;
203- let input = Request :: String ( "What is the result of 2 + 2?" . to_string ( ) ) ;
204- let output = input. execute ( Model :: GPT4Turbo , None ) . await . unwrap ( ) ;
220+ let mut request = Request :: String ( "What is the result of 2 + 2?" . to_string ( ) ) ;
221+ let output = request. execute ( Model :: GPT4Turbo , None ) . await . unwrap ( ) ;
222+
223+ println ! ( "Output:\n {}" , output) ;
224+ assert ! ( output. contains( '4' ) ) ;
225+ }
226+
227+ #[ tokio:: test]
228+ #[ ignore = "run this manually" ]
229+ async fn test_openai_chat ( ) {
230+ dotenvy:: dotenv ( ) . unwrap ( ) ;
231+ let request = ChatHistoryRequest {
232+ history_id : 0 ,
233+ content : "What is 2+2?" . to_string ( ) ,
234+ } ;
235+ let request_bytes = serde_json:: to_vec ( & request) . unwrap ( ) ;
236+ let mut request = Request :: try_parse_bytes ( & request_bytes. into ( ) )
237+ . await
238+ . unwrap ( ) ;
239+ let output = request. execute ( Model :: GPT4Turbo , None ) . await . unwrap ( ) ;
205240
206241 println ! ( "Output:\n {}" , output) ;
207242 assert ! ( output. contains( '4' ) ) ;
0 commit comments