From a95e5461c350ace8280b3ce897502c4488b88266 Mon Sep 17 00:00:00 2001
From: Simone Margaritelli
Date: Fri, 30 Aug 2024 10:21:47 +0200
Subject: [PATCH] new: implemented huggingface message api support (closes #21)

---
 Cargo.toml                         |  3 +-
 README.md                          | 10 +++++-
 src/agent/generator/huggingface.rs | 40 +++++++++++++++++++++
 src/agent/generator/mod.rs         |  8 +++++
 src/cli.rs                         | 56 ++++++++++++++++++++++++++++--
 5 files changed, 112 insertions(+), 5 deletions(-)
 create mode 100644 src/agent/generator/huggingface.rs

diff --git a/Cargo.toml b/Cargo.toml
index b6488ac..d9ce9ce 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -55,12 +55,13 @@ reqwest_cookie_store = "0.8.0"
 serde_json = "1.0.120"
 
 [features]
-default = ["ollama", "groq", "openai", "fireworks"]
+default = ["ollama", "groq", "openai", "fireworks", "hf"]
 
 ollama = ["dep:ollama-rs"]
 groq = ["dep:groq-api-rs", "dep:duration-string"]
 openai = ["dep:openai_api_rust"]
 fireworks = ["dep:openai_api_rust"]
+hf = ["dep:openai_api_rust"]
 
 [profile.release]
 lto = true # Enable link-time optimization
diff --git a/README.md b/README.md
index 58200a6..cac1835 100644
--- a/README.md
+++ b/README.md
@@ -34,7 +34,7 @@ While Nerve was inspired by other projects such as Autogen and Rigging, its main
 
 ## LLM Support
 
-Nerve features integrations for any model accessible via the [ollama](https://github.com/ollama/ollama), [groq](https://groq.com), [OpenAI](https://openai.com/index/openai-api/) and [Fireworks](https://fireworks.ai/) APIs.
+Nerve features integrations for any model accessible via the [ollama](https://github.com/ollama/ollama), [groq](https://groq.com), [OpenAI](https://openai.com/index/openai-api/), [Fireworks](https://fireworks.ai/) and [Huggingface](https://huggingface.co/blog/tgi-messages-api#using-inference-endpoints-with-openai-client-libraries) APIs.
 
 **The tool will automatically detect if the selected model natively supports function calling. If not, it will provide a compatibility layer that empowers older models to perform function calling anyway.**
 
@@ -64,6 +64,14 @@ For **Fireworks**:
 LLM_FIREWORKS_KEY=you-api-key nerve -G "fireworks://llama-v3-70b-instruct" ...
 ```
 
+For **Huggingface**:
+
+Refer to [this document](https://huggingface.co/blog/tgi-messages-api#using-inference-endpoints-with-openai-client-libraries) for how to configure a custom Huggingface endpoint.
+
+```sh
+HF_API_TOKEN=your-api-key nerve -G "hf://tgi@your-custom-endpoint.aws.endpoints.huggingface.cloud" ...
+```
+
 ## Example
 
 Let's take a look at the `examples/ssh_agent` example tasklet (a "tasklet" is a YAML file describing a task and the instructions):
diff --git a/src/agent/generator/huggingface.rs b/src/agent/generator/huggingface.rs
new file mode 100644
index 0000000..33dc490
--- /dev/null
+++ b/src/agent/generator/huggingface.rs
@@ -0,0 +1,40 @@
+use anyhow::Result;
+use async_trait::async_trait;
+
+use crate::agent::{state::SharedState, Invocation};
+
+use super::{openai::OpenAIClient, Client, Options};
+
+pub struct HuggingfaceMessageClient {
+    client: OpenAIClient,
+}
+
+#[async_trait]
+impl Client for HuggingfaceMessageClient {
+    fn new(url: &str, _: u16, model_name: &str, _: u32) -> anyhow::Result<Self>
+    where
+        Self: Sized,
+    {
+        let message_api = format!("https://{}/v1/", url);
+        let client = OpenAIClient::custom(model_name, "HF_API_TOKEN", &message_api)?;
+
+        log::debug!("using huggingface message api @ {}", message_api);
+
+        Ok(Self { client })
+    }
+
+    async fn chat(
+        &self,
+        state: SharedState,
+        options: &Options,
+    ) -> anyhow::Result<(String, Vec<Invocation>)> {
+        self.client.chat(state, options).await
+    }
+}
+
+#[async_trait]
+impl mini_rag::Embedder for HuggingfaceMessageClient {
+    async fn embed(&self, text: &str) -> Result<mini_rag::Embeddings> {
+        self.client.embed(text).await
+    }
+}
diff --git a/src/agent/generator/mod.rs b/src/agent/generator/mod.rs
index baeca27..f0873be 100644
--- a/src/agent/generator/mod.rs
+++ b/src/agent/generator/mod.rs
@@ -13,6 +13,8 @@ use super::{state::SharedState, Invocation};
 mod fireworks;
 #[cfg(feature = "groq")]
 mod groq;
+#[cfg(feature = "hf")]
+mod huggingface;
 #[cfg(feature = "ollama")]
 mod ollama;
 #[cfg(feature = "openai")]
@@ -153,6 +155,12 @@ macro_rules! factory_body {
             $model_name,
             $context_window,
         )?)),
+        "hf" => Ok(Box::new(huggingface::HuggingfaceMessageClient::new(
+            $url,
+            $port,
+            $model_name,
+            $context_window,
+        )?)),
         #[cfg(feature = "groq")]
         "groq" => Ok(Box::new(groq::GroqClient::new(
             $url,
diff --git a/src/cli.rs b/src/cli.rs
index b8d6698..68d2f97 100644
--- a/src/cli.rs
+++ b/src/cli.rs
@@ -8,7 +8,7 @@ use regex::Regex;
 lazy_static! {
     pub static ref PUBLIC_GENERATOR_PARSER: Regex = Regex::new(r"(?m)^(.+)://(.+)$").unwrap();
     pub static ref LOCAL_GENERATOR_PARSER: Regex =
-        Regex::new(r"(?m)^(.+)://(.+)@(.+):(\d+)$").unwrap();
+        Regex::new(r"(?m)^(.+)://(.+)@([^:]+):?(\d+)?$").unwrap();
 }
 
 #[derive(Default)]
@@ -21,7 +21,7 @@ pub(crate) struct GeneratorOptions {
 }
 
 /// Get things done with LLMs.
-#[derive(Parser, Debug)]
+#[derive(Parser, Debug, Default)]
 #[command(version, about, long_about = None)]
 pub(crate) struct Args {
     /// Generator string as <type>://<model name>@<host>:<port>
@@ -96,7 +96,11 @@ impl Args {
                 .unwrap()
                 .as_str()
                 .clone_into(&mut generator.host);
-            generator.port = caps.get(4).unwrap().as_str().parse::<u16>().unwrap();
+            generator.port = if let Some(port) = caps.get(4) {
+                port.as_str().parse::<u16>().unwrap()
+            } else {
+                0
+            };
         } else {
             let caps = if let Some(caps) = PUBLIC_GENERATOR_PARSER.captures_iter(raw).next() {
                 caps
@@ -149,3 +153,49 @@ pub(crate) fn get_user_input(prompt: &str) -> String {
     println!();
     input.trim().to_string()
 }
+
+#[cfg(test)]
+mod tests {
+    use super::Args;
+
+    #[test]
+    fn test_wont_parse_invalid_generator() {
+        let mut args = Args::default();
+        args.generator = "not a valid generator".to_string();
+        let ret = args.to_generator_options();
+        assert!(ret.is_err());
+    }
+
+    #[test]
+    fn test_parse_local_generator_full() {
+        let mut args = Args::default();
+        args.generator = "ollama://llama3@localhost:11434".to_string();
+        let ret = args.to_generator_options().unwrap();
+        assert_eq!(ret.type_name, "ollama");
+        assert_eq!(ret.model_name, "llama3");
+        assert_eq!(ret.host, "localhost");
+        assert_eq!(ret.port, 11434);
+    }
+
+    #[test]
+    fn test_parse_local_generator_without_port() {
+        let mut args = Args::default();
+        args.generator = "ollama://llama3@localhost".to_string();
+        let ret = args.to_generator_options().unwrap();
+        assert_eq!(ret.type_name, "ollama");
+        assert_eq!(ret.model_name, "llama3");
+        assert_eq!(ret.host, "localhost");
+        assert_eq!(ret.port, 0);
+    }
+
+    #[test]
+    fn test_parse_public_generator() {
+        let mut args = Args::default();
+        args.generator = "groq://llama3".to_string();
+        let ret = args.to_generator_options().unwrap();
+        assert_eq!(ret.type_name, "groq");
+        assert_eq!(ret.model_name, "llama3");
+        assert_eq!(ret.host, "");
+        assert_eq!(ret.port, 0);
+    }
+}
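
Note (outside the patch): a minimal standalone sketch of how the relaxed `LOCAL_GENERATOR_PARSER` pattern behaves once the port group is optional. The regex is copied verbatim from the diff above; the `main` harness, the sample generator strings, and the `0` fallback (mirroring `Args::to_generator_options`) are illustrative assumptions only.

```rust
use regex::Regex;

fn main() {
    // Same pattern the patch installs: the host may not contain ':' and the
    // ':<port>' suffix is now optional.
    let local = Regex::new(r"(?m)^(.+)://(.+)@([^:]+):?(\d+)?$").unwrap();

    for raw in [
        "ollama://llama3@localhost:11434",
        "hf://tgi@your-custom-endpoint.aws.endpoints.huggingface.cloud",
    ] {
        if let Some(caps) = local.captures(raw) {
            // A missing port falls back to 0, as in the patched parsing code.
            let port = caps
                .get(4)
                .map(|m| m.as_str().parse::<u16>().unwrap())
                .unwrap_or(0);
            println!(
                "type={} model={} host={} port={}",
                &caps[1], &caps[2], &caps[3], port
            );
        }
    }
}
```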