diff --git a/Cargo.lock b/Cargo.lock index 2b3aa63..ca5b780 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -393,6 +393,29 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "env_filter" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a009aa4810eb158359dda09d0c87378e4bbb89b5a801f016885a4707ba24f7ea" +dependencies = [ + "log", + "regex", +] + +[[package]] +name = "env_logger" +version = "0.11.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "38b35839ba51819680ba087cd351788c9a3c476841207e0b8cee0b04722343b9" +dependencies = [ + "anstream", + "anstyle", + "env_filter", + "humantime", + "log", +] + [[package]] name = "equivalent" version = "1.0.1" @@ -699,6 +722,12 @@ version = "0.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "91f255a4535024abf7640cb288260811fc14794f62b063652ed349f9a6c2348e" +[[package]] +name = "humantime" +version = "2.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4" + [[package]] name = "hyper" version = "1.3.1" @@ -930,9 +959,9 @@ dependencies = [ [[package]] name = "log" -version = "0.4.21" +version = "0.4.22" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "90ed8c1e510134f979dbc4f070f87d4313098b704861a105fe34231c70a3901c" +checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24" [[package]] name = "lopdf" @@ -983,15 +1012,15 @@ checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" [[package]] name = "mini-rag" -version = "0.2.1" +version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "71facc5e6b5af46f4df9cf6ad32828e2d871420ddc94d73be7c41abbfb3fdcda" +checksum = "c352137db900aafef20d87aaf54e96beffeb28faa3d984b7be8f349da50e3e32" dependencies = [ "anyhow", "async-trait", "bitcode", - "colored", "glob", + "log", "lopdf", "rayon", "serde", @@ -1053,6 +1082,7 @@ dependencies = [ "clap", "colored", "duration-string", + "env_logger", "glob", "groq-api-rs", "human_bytes", @@ -1061,6 +1091,7 @@ dependencies = [ "itertools", "lazy_static", "libc", + "log", "memory-stats", "mini-rag", "ollama-rs", diff --git a/Cargo.toml b/Cargo.toml index bb8bf41..e52e000 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -46,7 +46,9 @@ memory-stats = "1.1.0" sha256 = "1.5.0" bitcode = { version = "0.6.0", features = ["serde"] } intertrait = "0.2.2" -mini-rag = "0.2.1" +mini-rag = "0.2.2" +env_logger = "0.11.3" +log = "0.4.22" [features] default = ["ollama", "groq", "openai", "fireworks"] diff --git a/src/agent/events/channel.rs b/src/agent/events/channel.rs new file mode 100644 index 0000000..09a6ea3 --- /dev/null +++ b/src/agent/events/channel.rs @@ -0,0 +1,6 @@ +pub(crate) type Sender = tokio::sync::mpsc::UnboundedSender; +pub(crate) type Receiver = tokio::sync::mpsc::UnboundedReceiver; + +pub(crate) fn create_channel() -> (Sender, Receiver) { + tokio::sync::mpsc::unbounded_channel() +} diff --git a/src/agent/events/mod.rs b/src/agent/events/mod.rs new file mode 100644 index 0000000..0ed4b66 --- /dev/null +++ b/src/agent/events/mod.rs @@ -0,0 +1,44 @@ +use serde::{Deserialize, Serialize}; + +mod channel; + +pub(crate) use channel::*; + +use super::{ + generator::Options, + state::{metrics::Metrics, storage::StorageType}, + Invocation, +}; + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub(crate) enum Event { + MetricsUpdate(Metrics), + StorageUpdate { + storage_name: String, + storage_type: StorageType, + key: String, + prev: Option, + new: Option, + }, + StateUpdate(Options), + EmptyResponse, + InvalidResponse(String), + InvalidAction { + invocation: Invocation, + error: Option, + }, + ActionTimeout { + invocation: Invocation, + elapsed: std::time::Duration, + }, + ActionExecuted { + invocation: Invocation, + error: Option, + result: Option, + elapsed: std::time::Duration, + }, + TaskComplete { + impossible: bool, + reason: Option, + }, +} diff --git a/src/agent/generator/mod.rs b/src/agent/generator/mod.rs index c69eb3c..c6f8d3e 100644 --- a/src/agent/generator/mod.rs +++ b/src/agent/generator/mod.rs @@ -2,10 +2,10 @@ use std::{fmt::Display, time::Duration}; use anyhow::Result; use async_trait::async_trait; -use colored::Colorize; use duration_string::DurationString; use lazy_static::lazy_static; use regex::Regex; +use serde::{Deserialize, Serialize}; use super::Invocation; @@ -24,7 +24,7 @@ lazy_static! { static ref CONN_RESET_PARSER: Regex = Regex::new(r"(?m)^.+onnection reset by peer.*").unwrap(); } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Serialize, Deserialize)] pub struct Options { pub system_prompt: String, pub prompt: String, @@ -41,7 +41,7 @@ impl Options { } } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Serialize, Deserialize)] pub enum Message { Agent(String, Option), Feedback(String, Option), @@ -86,9 +86,8 @@ pub trait Client: mini_rag::Embedder + Send + Sync { } if let Ok(retry_time) = retry_time_str.parse::() { - println!( - "{}: rate limit reached for this model, retrying in {} ...\n", - "WARNING".bold().yellow(), + log::warn!( + "rate limit reached for this model, retrying in {} ...", retry_time, ); @@ -99,16 +98,15 @@ pub trait Client: mini_rag::Embedder + Send + Sync { return true; } else { - eprintln!("can't parse '{}'", &retry_time_str); + log::error!("can't parse '{}'", &retry_time_str); } } else { - eprintln!("cap len wrong"); + log::error!("cap len wrong"); } } else if CONN_RESET_PARSER.captures_iter(error).next().is_some() { let retry_time = Duration::from_secs(5); - println!( - "{}: connection reset by peer, retrying in {:?} ...\n", - "WARNING".bold().yellow(), + log::warn!( + "connection reset by peer, retrying in {:?} ...", &retry_time, ); diff --git a/src/agent/generator/ollama.rs b/src/agent/generator/ollama.rs index d371984..ed2c6f1 100644 --- a/src/agent/generator/ollama.rs +++ b/src/agent/generator/ollama.rs @@ -82,7 +82,7 @@ impl Client for OllamaClient { if let Some(msg) = res.message { Ok(msg.content) } else { - println!("WARNING: model returned an empty message."); + log::warn!("model returned an empty message."); Ok("".to_string()) } } diff --git a/src/agent/mod.rs b/src/agent/mod.rs index 1a05fa2..eba980f 100644 --- a/src/agent/mod.rs +++ b/src/agent/mod.rs @@ -1,27 +1,36 @@ use std::{collections::HashMap, sync::Arc, time::Duration}; use anyhow::Result; -use colored::Colorize; +use mini_rag::Embedder; +use serde::{Deserialize, Serialize}; +use events::Event; use generator::{Client, Options}; -use mini_rag::Embedder; use namespaces::Action; +use serialization::xml::serialize; use state::{SharedState, State}; use task::Task; +pub mod events; pub mod generator; pub mod namespaces; pub mod serialization; pub mod state; pub mod task; -#[derive(Debug, Default, Clone, PartialEq, Eq)] +#[derive(Debug, Default, Clone, PartialEq, Eq, Serialize, Deserialize)] pub struct Invocation { pub action: String, pub attributes: Option>, pub payload: Option, } +impl std::fmt::Display for Invocation { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", serialize::invocation(self)) + } +} + impl std::hash::Hash for Invocation { fn hash(&self, state: &mut H) { self.action.hash(state); @@ -44,39 +53,32 @@ impl Invocation { } } -#[derive(Debug, Clone)] -pub struct AgentOptions { - pub max_iterations: usize, - pub save_to: Option, - pub full_dump: bool, - pub with_stats: bool, -} - pub struct Agent { + events_chan: events::Sender, generator: Box, state: SharedState, - options: AgentOptions, max_history: u16, task_timeout: Option, } impl Agent { pub async fn new( + events_chan: events::Sender, generator: Box, embedder: Box, task: Box, - options: AgentOptions, + max_iterations: usize, ) -> Result { let max_history = task.max_history_visibility(); let task_timeout = task.get_timeout(); let state = Arc::new(tokio::sync::Mutex::new( - State::new(task, embedder, options.max_iterations).await?, + State::new(events_chan.clone(), task, embedder, max_iterations).await?, )); Ok(Self { + events_chan, generator, state, - options, max_history, task_timeout, }) @@ -147,66 +149,38 @@ impl Agent { self.state.lock().await.is_complete() } - async fn save_if_needed(&self, options: &Options, refresh: bool) -> Result<()> { - if let Some(prompt_path) = &self.options.save_to { - let mut opts = options.clone(); - if refresh { - opts.system_prompt = - serialization::state_to_system_prompt(&*self.state.lock().await)?; - opts.history = self - .state - .lock() - .await - .to_chat_history(self.max_history as usize)?; - } - - let data = if self.options.full_dump { - format!( - "[SYSTEM PROMPT]\n\n{}\n\n[PROMPT]\n\n{}\n\n[CHAT]\n\n{}", - &options.system_prompt, - &options.prompt, - options - .history - .iter() - .map(|m| m.to_string()) - .collect::>() - .join("\n") - ) - } else { - opts.system_prompt.to_string() - }; - - std::fs::write(prompt_path, data)?; + async fn on_state_update(&self, options: &Options, refresh: bool) -> Result<()> { + let mut opts = options.clone(); + if refresh { + opts.system_prompt = serialization::state_to_system_prompt(&*self.state.lock().await)?; + opts.history = self + .state + .lock() + .await + .to_chat_history(self.max_history as usize)?; } - Ok(()) + self.on_event(events::Event::StateUpdate(opts)) } async fn on_empty_response(&self) { - println!( - "{}: agent did not provide valid instructions: empty response", - "WARNING".bold().red(), - ); - let mut mut_state = self.state.lock().await; mut_state.metrics.errors.empty_responses += 1; mut_state .add_unparsed_response_to_history("", "Do not return an empty responses.".to_string()); + + self.on_event(events::Event::EmptyResponse).unwrap(); } async fn on_invalid_response(&self, response: &str) { - println!( - "{}: agent did not provide valid instructions: \n\n{}\n\n", - "WARNING".bold().red(), - response.dimmed() - ); - let mut mut_state = self.state.lock().await; mut_state.metrics.errors.unparsed_responses += 1; mut_state.add_unparsed_response_to_history( response, "I could not parse any valid actions from your response, please correct it according to the instructions.".to_string(), ); + self.on_event(events::Event::InvalidResponse(response.to_string())) + .unwrap(); } async fn on_valid_response(&self) { @@ -220,9 +194,14 @@ impl Agent { let name = invocation.action.clone(); mut_state.add_error_to_history( - invocation, - error.unwrap_or(format!("'{name}' is not a valid action name")), + invocation.clone(), + error + .clone() + .unwrap_or(format!("'{name}' is not a valid action name")), ); + + self.on_event(events::Event::InvalidAction { invocation, error }) + .unwrap(); } async fn on_valid_action(&self) { @@ -230,29 +209,51 @@ impl Agent { } async fn on_timed_out_action(&self, invocation: Invocation, start: &std::time::Instant) { - println!( - "{}: action '{}' timed out after {:?}", - "WARNING".bold().yellow(), - &invocation.action, - start.elapsed() - ); let mut mut_state = self.state.lock().await; mut_state.metrics.errors.timedout_actions += 1; // tell the model about the timeout - mut_state.add_error_to_history(invocation, "action timed out".to_string()); + mut_state.add_error_to_history(invocation.clone(), "action timed out".to_string()); + + self.events_chan + .send(events::Event::ActionTimeout { + invocation, + elapsed: start.elapsed(), + }) + .unwrap(); } - async fn on_executed_action(&self, invocation: Invocation, ret: Result>) { + async fn on_executed_action( + &self, + invocation: Invocation, + ret: Result>, + start: &std::time::Instant, + ) { let mut mut_state = self.state.lock().await; - if let Err(error) = ret { + let mut error = None; + let mut result = None; + + if let Err(err) = ret { mut_state.metrics.errors.errored_actions += 1; // tell the model about the error - mut_state.add_error_to_history(invocation, error.to_string()); + mut_state.add_error_to_history(invocation.clone(), err.to_string()); + + error = Some(err.to_string()); } else { + let ret = ret.unwrap(); mut_state.metrics.success_actions += 1; // tell the model about the output - mut_state.add_success_to_history(invocation, ret.unwrap()); + mut_state.add_success_to_history(invocation.clone(), ret.clone()); + + result = ret; } + + self.on_event(events::Event::ActionExecuted { + invocation, + result, + error, + elapsed: start.elapsed(), + }) + .unwrap(); } pub async fn get_metrics(&self) -> state::metrics::Metrics { @@ -264,9 +265,7 @@ impl Agent { mut_state.on_step()?; - if self.options.with_stats { - println!("\n{}\n", &mut_state.metrics); - } + self.on_event(events::Event::MetricsUpdate(mut_state.metrics.clone()))?; let system_prompt = serialization::state_to_system_prompt(&mut_state)?; let prompt = mut_state.to_prompt()?; @@ -276,10 +275,14 @@ impl Agent { Ok(options) } + pub fn on_event(&self, event: Event) -> Result<()> { + self.events_chan.send(event).map_err(|e| anyhow!(e)) + } + pub async fn step(&mut self) -> Result<()> { let options = self.prepare_step().await?; - self.save_if_needed(&options, false).await?; + self.on_state_update(&options, false).await?; // run model inference let response = self.generator.chat(&options).await?.trim().to_string(); @@ -323,8 +326,6 @@ impl Agent { Duration::from_secs(60 * 60 * 24 * 30) }; - // println!("{} timeout={:?}", action.name(), &timeout); - // execute with timeout let start = std::time::Instant::now(); let ret = tokio::time::timeout( @@ -340,12 +341,12 @@ impl Agent { if ret.is_err() { self.on_timed_out_action(inv, &start).await; } else { - self.on_executed_action(inv, ret.unwrap()).await; + self.on_executed_action(inv, ret.unwrap(), &start).await; } } } - self.save_if_needed(&options, true).await?; + self.on_state_update(&options, true).await?; // break the loop if we're done if self.state.lock().await.is_complete() { diff --git a/src/agent/namespaces/filesystem/mod.rs b/src/agent/namespaces/filesystem/mod.rs index de18aaa..d7de0fb 100644 --- a/src/agent/namespaces/filesystem/mod.rs +++ b/src/agent/namespaces/filesystem/mod.rs @@ -7,7 +7,6 @@ use chrono::{DateTime, Local}; use libc::{S_IRGRP, S_IROTH, S_IRUSR, S_IWGRP, S_IWOTH, S_IWUSR, S_IXGRP, S_IXOTH, S_IXUSR}; use anyhow::Result; -use colored::Colorize; use super::{Action, Namespace}; use crate::agent::state::SharedState; @@ -102,20 +101,12 @@ impl Action for ReadFolder { full_path.display() ); } else { - eprintln!("ERROR: {:?}", path); + log::error!("{:?}", path); } } - println!( - "<{}> {} -> {} bytes", - self.name().bold(), - folder.yellow(), - output.len() - ); - Ok(Some(output)) } else { - eprintln!("<{}> {} -> {:?}", self.name().bold(), folder.red(), &ret); Err(anyhow!("can't read {}: {:?}", folder, ret)) } } @@ -145,24 +136,11 @@ impl Action for ReadFile { payload: Option, ) -> Result> { let filepath = payload.unwrap(); - let ret = std::fs::read_to_string(&filepath); + let ret = std::fs::read_to_string(filepath); if let Ok(contents) = ret { - println!( - "<{}> {} -> {} bytes", - self.name().bold(), - filepath.yellow(), - contents.len() - ); Ok(Some(contents)) } else { let err = ret.err().unwrap(); - println!( - "<{}> {} -> {:?}", - self.name().bold(), - filepath.yellow(), - &err - ); - Err(anyhow!(err)) } } diff --git a/src/agent/namespaces/goal/mod.rs b/src/agent/namespaces/goal/mod.rs index 954ee8a..12ba338 100644 --- a/src/agent/namespaces/goal/mod.rs +++ b/src/agent/namespaces/goal/mod.rs @@ -33,7 +33,7 @@ impl Action for UpdateGoal { .lock() .await .get_storage_mut("goal")? - .set_current(payload.as_ref().unwrap(), true); + .set_current(payload.as_ref().unwrap()); Ok(Some("goal updated".to_string())) } } diff --git a/src/agent/namespaces/memory/mod.rs b/src/agent/namespaces/memory/mod.rs index d5d967a..c1cfb73 100644 --- a/src/agent/namespaces/memory/mod.rs +++ b/src/agent/namespaces/memory/mod.rs @@ -2,7 +2,6 @@ use std::collections::HashMap; use anyhow::Result; use async_trait::async_trait; -use colored::Colorize; use super::{Action, Namespace, StorageDescriptor}; use crate::agent::state::SharedState; @@ -125,10 +124,8 @@ impl Action for RecallMemory { let key = attrs.get("key").unwrap(); if let Some(memory) = state.lock().await.get_storage("memories")?.get_tagged(key) { - println!("<{}> recalling {}", "memories".bold(), key); Ok(Some(memory)) } else { - eprintln!("<{}> memory {} does not exist", "memories".bold(), key); Err(anyhow!("memory '{}' not found", key)) } } diff --git a/src/agent/namespaces/rag/mod.rs b/src/agent/namespaces/rag/mod.rs index cd135a7..cb79371 100644 --- a/src/agent/namespaces/rag/mod.rs +++ b/src/agent/namespaces/rag/mod.rs @@ -2,7 +2,6 @@ use std::{collections::HashMap, time::Instant}; use anyhow::Result; use async_trait::async_trait; -use colored::Colorize; use crate::agent::state::SharedState; @@ -37,11 +36,12 @@ impl Action for Search { let mut docs = state.lock().await.rag_query(&query, 1).await?; if !docs.is_empty() { - println!("\n {} results in {:?}", docs.len(), start.elapsed()); - for (doc, score) in &docs { - println!(" * {} ({})", doc.get_path(), score); - } - println!(); + log::info!( + "rag search for '{}': {} results in {:?}", + query, + docs.len(), + start.elapsed() + ); Ok(Some(format!( "Here is some supporting information:\n\n{}", @@ -51,9 +51,8 @@ impl Action for Search { .join("\n") ))) } else { - println!( - "[{}] no results for '{query}' in {:?}", - "rag".bold(), + log::info!( + "search: no results for query '{query}' in {:?}", start.elapsed() ); Ok(Some("no documents for this query".to_string())) diff --git a/src/agent/serialization/xml/parsing.rs b/src/agent/serialization/xml/parsing.rs index 2952583..d51ca59 100644 --- a/src/agent/serialization/xml/parsing.rs +++ b/src/agent/serialization/xml/parsing.rs @@ -107,7 +107,7 @@ fn try_parse_block(ptr: &str) -> Parsed { loop { let event = parser.next(); if let Ok(event) = event { - // println!("{:?}", &event); + log::debug!("{:?}", &event); match event { xml::reader::XmlEvent::StartDocument { version: _, @@ -136,12 +136,12 @@ fn try_parse_block(ptr: &str) -> Parsed { if let Ok(inv) = ret { parsed.invocations.push(inv); } else { - eprintln!("WARNING: {:?}", ret.err().unwrap()); + log::error!("{:?}", ret.err().unwrap()); } break; } _ => { - eprintln!("WARNING: unexpected xml element: {:?}", event); + log::error!("unexpected xml element: {:?}", event); } } } else { diff --git a/src/agent/state/metrics.rs b/src/agent/state/metrics.rs index cc4313c..028a77a 100644 --- a/src/agent/state/metrics.rs +++ b/src/agent/state/metrics.rs @@ -1,9 +1,9 @@ use std::fmt::Display; -use colored::Colorize; use memory_stats::memory_stats; +use serde::{Deserialize, Serialize}; -#[derive(Debug, Clone, Default)] +#[derive(Debug, Clone, Default, Serialize, Deserialize)] pub struct ErrorMetrics { pub empty_responses: usize, pub unparsed_responses: usize, @@ -23,7 +23,7 @@ impl ErrorMetrics { } } -#[derive(Debug, Clone, Default)] +#[derive(Debug, Clone, Default, Serialize, Deserialize)] pub struct Metrics { pub max_steps: usize, pub current_step: usize, @@ -35,7 +35,7 @@ pub struct Metrics { impl Display for Metrics { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "[{}] steps:", "statistics".bold().blue())?; + write!(f, "step:")?; if self.max_steps > 0 { write!(f, "{}/{} ", self.current_step, self.max_steps)?; } else { diff --git a/src/agent/state/mod.rs b/src/agent/state/mod.rs index 724cce7..3a0a1fe 100644 --- a/src/agent/state/mod.rs +++ b/src/agent/state/mod.rs @@ -1,10 +1,10 @@ use std::{collections::HashMap, sync::Arc}; use anyhow::Result; -use colored::Colorize; use metrics::Metrics; use super::{ + events::Event, generator::Message, namespaces::{self, Namespace}, task::Task, @@ -30,6 +30,8 @@ pub struct State { rag: Option, // set to true when task is complete complete: bool, + // events channel + events_tx: super::events::Sender, // runtime metrics pub metrics: Metrics, } @@ -38,6 +40,7 @@ pub type SharedState = Arc>; impl State { pub async fn new( + events_tx: super::events::Sender, task: Box, embedder: Box, max_iterations: usize, @@ -106,19 +109,17 @@ impl State { if !storages.contains_key(&storage.name) { storages.insert( storage.name.to_string(), - Storage::new(&storage.name, storage.type_), + Storage::new(&storage.name, storage.type_, events_tx.clone()), ); } } } } - // println!("storages={:?}", &storages); - // if the goal namespace is enabled, set the current goal if let Some(goal) = storages.get_mut("goal") { let prompt = task.to_prompt()?; - goal.set_current(&prompt, false); + goal.set_current(&prompt); } let metrics = Metrics { @@ -134,6 +135,7 @@ impl State { complete, metrics, rag, + events_tx, }) } @@ -175,7 +177,6 @@ impl State { if let Some(storage) = self.storages.get(name) { Ok(storage) } else { - println!("WARNING: requested storage '{name}' not found."); Err(anyhow!("storage {name} not found")) } } @@ -184,7 +185,6 @@ impl State { if let Some(storage) = self.storages.get_mut(name) { Ok(storage) } else { - println!("WARNING: requested storage '{name}' not found."); Err(anyhow!("storage {name} not found")) } } @@ -207,7 +207,6 @@ impl State { } pub fn add_error_to_history(&mut self, invocation: Invocation, error: String) { - // eprintln!("[{}] -> {}", &invocation.action, error.red()); self.history.push(Execution::with_error(invocation, error)); } @@ -229,30 +228,11 @@ impl State { } pub fn on_complete(&mut self, impossible: bool, reason: Option) -> Result<()> { - // TODO: unify logging logic - if impossible { - println!( - "\n{}: '{}'", - "task is impossible".bold().red(), - if let Some(r) = &reason { - r - } else { - "no reason provided" - } - ); - } else { - println!( - "\n{}: '{}'", - "task complete".bold().green(), - if let Some(r) = &reason { - r - } else { - "no reason provided" - } - ); - } - self.complete = true; - Ok(()) + self.on_event(Event::TaskComplete { impossible, reason }) + } + + pub fn on_event(&self, event: Event) -> Result<()> { + self.events_tx.send(event).map_err(|e| anyhow!(e)) } } diff --git a/src/agent/state/storage.rs b/src/agent/state/storage.rs index 92e0388..654d6c0 100644 --- a/src/agent/state/storage.rs +++ b/src/agent/state/storage.rs @@ -1,7 +1,9 @@ use std::{ops::Deref /* , time::SystemTime*/}; -use colored::Colorize; use indexmap::IndexMap; +use serde::{Deserialize, Serialize}; + +use crate::agent::events::{Event, Sender}; #[derive(Debug)] pub struct Entry { @@ -22,7 +24,7 @@ impl Entry { } #[allow(dead_code)] -#[derive(Debug, Clone, Copy)] +#[derive(Debug, Clone, Copy, Serialize, Deserialize)] pub enum StorageType { // a list indexed by element position Untagged, @@ -50,6 +52,7 @@ pub(crate) const PREVIOUS_TAG: &str = "__previous"; #[derive(Debug)] pub struct Storage { + events_tx: Sender, name: String, type_: StorageType, inner: IndexMap, @@ -65,10 +68,15 @@ impl Deref for Storage { #[allow(dead_code)] impl Storage { - pub fn new(name: &str, type_: StorageType) -> Self { + pub fn new(name: &str, type_: StorageType, events_tx: Sender) -> Self { let name = name.to_string(); let inner = IndexMap::new(); - Self { name, type_, inner } + Self { + name, + type_, + inner, + events_tx, + } } pub fn get_name(&self) -> &str { @@ -79,17 +87,36 @@ impl Storage { &self.type_ } + fn on_event(&self, event: Event) { + self.events_tx.send(event).unwrap(); + } + pub fn add_tagged(&mut self, key: &str, data: &str) { assert!(matches!(self.type_, StorageType::Tagged)); - println!("<{}> {}={}", self.name.bold(), key, data.yellow()); + self.inner .insert(key.to_string(), Entry::new(data.to_string())); + + self.on_event(Event::StorageUpdate { + storage_name: self.name.to_string(), + storage_type: self.type_, + key: key.to_string(), + prev: None, + new: Some(data.to_string()), + }); } pub fn del_tagged(&mut self, key: &str) -> Option { assert!(matches!(self.type_, StorageType::Tagged)); if let Some(old) = self.inner.shift_remove(key) { - println!("<{}> {} removed\n", self.name.bold(), key); + self.on_event(Event::StorageUpdate { + storage_name: self.name.to_string(), + storage_type: self.type_, + key: key.to_string(), + prev: Some(old.data.to_string()), + new: None, + }); + Some(old.data) } else { None @@ -103,16 +130,31 @@ impl Storage { pub fn add_completion(&mut self, data: &str) { assert!(matches!(self.type_, StorageType::Completion)); - println!("<{}> {}", self.name.bold(), data.yellow()); let tag = format!("{}", self.inner.len() + 1); - self.inner.insert(tag, Entry::new(data.to_string())); + self.inner + .insert(tag.to_string(), Entry::new(data.to_string())); + + self.on_event(Event::StorageUpdate { + storage_name: self.name.to_string(), + storage_type: self.type_, + key: tag, + prev: None, + new: Some(data.to_string()), + }); } pub fn del_completion(&mut self, pos: usize) -> Option { assert!(matches!(self.type_, StorageType::Completion)); let tag = format!("{}", pos); if let Some(old) = self.inner.shift_remove(&tag) { - println!("<{}> element {} removed\n", self.name.bold(), pos); + self.on_event(Event::StorageUpdate { + storage_name: self.name.to_string(), + storage_type: self.type_, + key: tag, + prev: Some(old.data.to_string()), + new: None, + }); + Some(old.data) } else { None @@ -123,9 +165,17 @@ impl Storage { assert!(matches!(self.type_, StorageType::Completion)); let tag = format!("{}", pos); if let Some(entry) = self.inner.get_mut(&tag) { - println!("<{}> element {} set as complete\n", self.name.bold(), pos); let prev = entry.complete; entry.complete = true; + + self.on_event(Event::StorageUpdate { + storage_name: self.name.to_string(), + storage_type: self.type_, + key: tag, + prev: Some((if prev { "complete" } else { "incomplete" }).to_string()), + new: Some("complete".to_string()), + }); + Some(prev) } else { None @@ -136,9 +186,17 @@ impl Storage { assert!(matches!(self.type_, StorageType::Completion)); let tag = format!("{}", pos); if let Some(entry) = self.inner.get_mut(&tag) { - println!("<{}> element {} set as incomplete\n", self.name.bold(), pos); let prev = entry.complete; entry.complete = false; + + self.on_event(Event::StorageUpdate { + storage_name: self.name.to_string(), + storage_type: self.type_, + key: tag, + prev: Some((if prev { "complete" } else { "incomplete" }).to_string()), + new: Some("incomplete".to_string()), + }); + Some(prev) } else { None @@ -147,39 +205,67 @@ impl Storage { pub fn add_untagged(&mut self, data: &str) { assert!(matches!(self.type_, StorageType::Untagged)); - println!("<{}> {}", self.name.bold(), data.yellow()); let tag = format!("{}", self.inner.len() + 1); - self.inner.insert(tag, Entry::new(data.to_string())); + self.inner + .insert(tag.to_string(), Entry::new(data.to_string())); + + self.on_event(Event::StorageUpdate { + storage_name: self.name.to_string(), + storage_type: self.type_, + key: tag, + prev: None, + new: Some(data.to_string()), + }); } pub fn del_untagged(&mut self, pos: usize) -> Option { assert!(matches!(self.type_, StorageType::Untagged)); let tag = format!("{}", pos); if let Some(old) = self.inner.shift_remove(&tag) { - println!("<{}> element {} removed\n", self.name.bold(), pos); + self.on_event(Event::StorageUpdate { + storage_name: self.name.to_string(), + storage_type: self.type_, + key: tag, + prev: Some(old.data.to_string()), + new: None, + }); Some(old.data) } else { None } } - pub fn set_current(&mut self, data: &str, verbose: bool) { + pub fn set_current(&mut self, data: &str) { assert!(matches!(self.type_, StorageType::CurrentPrevious)); - if verbose { - println!("<{}> current={}", self.name.bold(), data.yellow()); - } let old_current = self.inner.shift_remove(CURRENT_TAG); - self.inner .insert(CURRENT_TAG.to_string(), Entry::new(data.to_string())); - if let Some(old_curr) = old_current { + let prev = if let Some(old_curr) = old_current { + let data = old_curr.data.to_string(); self.inner.insert(PREVIOUS_TAG.to_string(), old_curr); - } + Some(data) + } else { + None + }; + + self.on_event(Event::StorageUpdate { + storage_name: self.name.to_string(), + storage_type: self.type_, + key: CURRENT_TAG.to_string(), + prev, + new: Some(data.to_string()), + }); } pub fn clear(&mut self) { self.inner.clear(); - println!("<{}> cleared", self.name.bold()); + self.on_event(Event::StorageUpdate { + storage_name: self.name.to_string(), + storage_type: self.type_, + key: "".to_string(), + prev: None, + new: None, + }); } } diff --git a/src/agent/task/tasklet.rs b/src/agent/task/tasklet.rs index 582b873..37cc6a8 100644 --- a/src/agent/task/tasklet.rs +++ b/src/agent/task/tasklet.rs @@ -67,11 +67,7 @@ impl Action for TaskletAction { if let Ok(tm) = timeout.parse::() { return Some(*tm); } else { - eprintln!( - "{}: can't parse '{}' as duration string", - "WARNING".yellow(), - timeout - ); + log::error!("can't parse '{}' as duration string", timeout); } } None @@ -118,11 +114,10 @@ impl Action for TaskletAction { } if let Some(payload) = &payload { - // println!("# {}", payload.bold()); cmd.arg(payload); } - println!( + log::info!( "{}{}{}", self.name.bold(), if payload.is_some() { @@ -146,7 +141,7 @@ impl Action for TaskletAction { }, ); - // println!("! {:?}", &cmd); + log::debug!("! {:?}", &cmd); let output = cmd.output(); if let Ok(output) = output { @@ -154,13 +149,13 @@ impl Action for TaskletAction { let out = String::from_utf8_lossy(&output.stdout).trim().to_string(); if !err.is_empty() { - println!( - "\n{}\n", + log::error!( + "{}", if err.len() > self.max_shown_output { format!( "{}\n{}", &err[0..self.max_shown_output].red(), - "".yellow() + "... truncated ...".yellow() ) } else { err.red().to_string() @@ -169,24 +164,28 @@ impl Action for TaskletAction { } if !out.is_empty() { - println!( - "\n{}\n", - if out.len() > self.max_shown_output { - let end = out - .char_indices() - .map(|(i, _)| i) - .nth(self.max_shown_output) - .unwrap(); - let ascii = &out[0..end]; - format!("{}\n{}", ascii, "".yellow()) - } else { - out.to_string() - } - ); + let lines = if out.len() > self.max_shown_output { + let end = out + .char_indices() + .map(|(i, _)| i) + .nth(self.max_shown_output) + .unwrap(); + let ascii = &out[0..end]; + format!("{}\n{}", ascii, "... truncated ...") + } else { + out.to_string() + } + .split('\n') + .map(|s| s.dimmed().to_string()) + .collect::>(); + + for line in lines { + log::info!("{}", line); + } } let exit_code = output.status.code().unwrap_or(0); - // println!("exit_code={}", exit_code); + log::debug!("exit_code={}", exit_code); if exit_code == STATE_COMPLETE_EXIT_CODE { state.lock().await.on_complete(false, Some(out))?; return Ok(Some("task complete".to_string())); @@ -199,8 +198,7 @@ impl Action for TaskletAction { } } else { let err = output.err().unwrap().to_string(); - println!("ERROR: {}", &err); - + log::error!("{}", &err); Err(anyhow!(err)) } } @@ -316,7 +314,7 @@ impl Tasklet { canon.file_stem().unwrap().to_str().unwrap().to_owned() }; - // println!("tasklet = {:?}", &tasklet); + log::debug!("tasklet = {:?}", &tasklet); Ok(tasklet) } @@ -364,11 +362,7 @@ impl Task for Tasklet { if let Ok(tm) = timeout.parse::() { return Some(*tm); } else { - eprintln!( - "{}: can't parse '{}' as duration string", - "WARNING".yellow(), - timeout - ); + log::error!("can't parse '{}' as duration string", timeout); } } None diff --git a/src/cli.rs b/src/cli.rs index 57d297e..7c34222 100644 --- a/src/cli.rs +++ b/src/cli.rs @@ -5,8 +5,6 @@ use clap::Parser; use lazy_static::lazy_static; use regex::Regex; -use crate::agent::AgentOptions; - lazy_static! { pub static ref PUBLIC_GENERATOR_PARSER: Regex = Regex::new(r"(?m)^(.+)://(.+)$").unwrap(); pub static ref LOCAL_GENERATOR_PARSER: Regex = @@ -66,15 +64,6 @@ pub(crate) struct Args { } impl Args { - pub fn to_agent_options(&self) -> AgentOptions { - AgentOptions { - max_iterations: self.max_iterations, - save_to: self.save_to.clone(), - full_dump: self.full_dump, - with_stats: self.stats, - } - } - fn parse_connection_string(&self, raw: &str, what: &str) -> Result { let raw = raw.trim().trim_matches(|c| c == '"' || c == '\''); if raw.is_empty() { @@ -153,7 +142,9 @@ impl Args { } pub(crate) fn get_user_input(prompt: &str) -> String { - print!("{}", prompt); + log::warn!("user prompt input required"); + + print!("\n{}", prompt); let _ = io::stdout().flush(); let mut input = String::new(); diff --git a/src/listener.rs b/src/listener.rs new file mode 100644 index 0000000..dd44087 --- /dev/null +++ b/src/listener.rs @@ -0,0 +1,137 @@ +use colored::Colorize; + +use crate::{ + agent::events::{Event, Receiver}, + cli, +}; + +pub(crate) async fn events_listener(args: cli::Args, mut events_rx: Receiver) { + while let Some(event) = events_rx.recv().await { + match event { + Event::MetricsUpdate(metrics) => { + log::info!("{}", metrics); + } + Event::StateUpdate(opts) => { + if let Some(prompt_path) = &args.save_to { + let data = if args.full_dump { + format!( + "[SYSTEM PROMPT]\n\n{}\n\n[PROMPT]\n\n{}\n\n[CHAT]\n\n{}", + &opts.system_prompt, + &opts.prompt, + opts.history + .iter() + .map(|m| m.to_string()) + .collect::>() + .join("\n") + ) + } else { + opts.system_prompt.to_string() + }; + + if let Err(e) = std::fs::write(prompt_path, data) { + log::error!("error writing {}: {:?}", prompt_path, e); + } + } + } + Event::EmptyResponse => { + log::warn!("agent did not provide valid instructions: empty response"); + } + Event::InvalidResponse(response) => { + log::warn!( + "agent did not provide valid instructions: \n\n{}\n\n", + response.dimmed() + ); + } + Event::InvalidAction { invocation, error } => { + log::warn!("invalid action {} : {:?}", invocation.action, error,); + } + Event::ActionTimeout { + invocation, + elapsed, + } => { + log::warn!( + "action '{}' timed out after {:?}", + invocation.action, + elapsed + ); + } + Event::ActionExecuted { + invocation, + error, + result, + elapsed, + } => { + if let Some(err) = error { + log::error!("{}: {}", invocation, err); + } else if let Some(res) = result { + log::debug!( + "{} -> {} bytes in {:?}", + invocation, + res.as_bytes().len(), + elapsed + ); + } else { + log::debug!("{} {} in {:?}", invocation, "no output".dimmed(), elapsed); + } + } + Event::TaskComplete { impossible, reason } => { + if impossible { + log::error!( + "{}: '{}'", + "task is impossible".bold().red(), + if let Some(r) = &reason { + r + } else { + "no reason provided" + } + ); + } else { + log::info!( + "{}: '{}'", + "task complete".bold().green(), + if let Some(r) = &reason { + r + } else { + "no reason provided" + } + ); + } + } + Event::StorageUpdate { + storage_name, + storage_type: _, + key, + prev, + new, + } => { + if prev.is_none() && new.is_none() { + log::info!("storage.{} cleared", storage_name.yellow().bold()); + } else if prev.is_none() && new.is_some() { + log::info!( + "storage.{}.{} > {}", + storage_name.yellow().bold(), + key, + new.unwrap().green() + ); + } else if prev.is_some() && new.is_none() { + log::info!("{}.{} removed", storage_name.yellow().bold(), key); + } else if new.is_some() { + log::info!( + "{}.{} > {}", + storage_name.yellow().bold(), + key, + new.unwrap().green() + ); + } else { + log::info!( + "{}.{} prev={:?} new={:?}", + storage_name.yellow().bold(), + key, + prev, + new + ); + } + } + } + } +} diff --git a/src/main.rs b/src/main.rs index 0f00009..87874e4 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,12 +1,13 @@ #[macro_use] extern crate anyhow; +use agent::events::Event; use anyhow::Result; use clap::Parser; -use colored::Colorize; mod agent; mod cli; +mod listener; mod setup; const APP_NAME: &str = env!("CARGO_BIN_NAME"); @@ -15,8 +16,8 @@ const APP_VERSION: &str = env!("CARGO_PKG_VERSION"); #[tokio::main] async fn main() -> Result<()> { // TODO: save/restore session - let args = cli::Args::parse(); + let with_stats = args.stats; if args.generate_doc { // generate action namespaces documentation and exit @@ -24,20 +25,33 @@ async fn main() -> Result<()> { std::process::exit(0); } - let mut agent = setup::setup_agent(&args).await?; + if std::env::var_os("RUST_LOG").is_none() { + // set `RUST_LOG=debug` to see debug logs + std::env::set_var("RUST_LOG", "info,openai_api_rust=warn"); + } + + env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info")) + .format_module_path(false) + .format_target(false) + .init(); + + let (mut agent, events_rx) = setup::setup_agent(&args).await?; + + // spawn the events listener + tokio::spawn(listener::events_listener(args, events_rx)); // keep going until the task is complete or a fatal error is reached while !agent.is_done().await { // next step if let Err(error) = agent.step().await { - println!("{}", error.to_string().bold().red()); + log::error!("{}", error.to_string()); return Err(error); } } // report final metrics on exit - if args.stats { - println!("\n{}", agent.get_metrics().await); + if with_stats { + agent.on_event(Event::MetricsUpdate(agent.get_metrics().await))?; } Ok(()) diff --git a/src/setup.rs b/src/setup.rs index ed18536..07f14c3 100644 --- a/src/setup.rs +++ b/src/setup.rs @@ -2,7 +2,12 @@ use anyhow::Result; use colored::Colorize; use crate::{ - agent::{generator, task::tasklet::Tasklet, Agent}, + agent::{ + events::{self, create_channel}, + generator, + task::tasklet::Tasklet, + Agent, + }, cli, APP_NAME, APP_VERSION, }; @@ -37,7 +42,7 @@ fn setup_models( Ok((gen_options, generator, embedder)) } -pub(crate) async fn setup_agent(args: &cli::Args) -> Result { +pub(crate) async fn setup_agent(args: &cli::Args) -> Result<(Agent, events::Receiver)> { // create generator and embedder let (gen_options, generator, embedder) = setup_models(args)?; @@ -52,7 +57,7 @@ pub(crate) async fn setup_agent(args: &cli::Args) -> Result { let tasklet_name = tasklet.name.clone(); println!( - "{} v{} 🧠 {}{} > {}", + "{} v{} 🧠 {}{} > {}\n", APP_NAME, APP_VERSION, gen_options.model_name.bold(), @@ -70,12 +75,11 @@ pub(crate) async fn setup_agent(args: &cli::Args) -> Result { tasklet.prepare(&args.prompt)?; - println!("task: {}\n", tasklet.prompt.as_ref().unwrap().green()); - let task = Box::new(tasklet); + let (tx, rx) = create_channel(); - // create the agent given the generator, embedder, task and a set of options - let agent = Agent::new(generator, embedder, task, args.to_agent_options()).await?; + // create the agent + let agent = Agent::new(tx, generator, embedder, task, args.max_iterations).await?; - Ok(agent) + Ok((agent, rx)) }