Skip to content

Commit

Permalink
new: implemented huggingface message api support (closes #21)
Browse files Browse the repository at this point in the history
  • Loading branch information
evilsocket committed Aug 30, 2024
1 parent c1640b7 commit a95e546
Show file tree
Hide file tree
Showing 5 changed files with 112 additions and 5 deletions.
3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -55,12 +55,13 @@ reqwest_cookie_store = "0.8.0"
serde_json = "1.0.120"

[features]
default = ["ollama", "groq", "openai", "fireworks"]
default = ["ollama", "groq", "openai", "fireworks", "hf"]

ollama = ["dep:ollama-rs"]
groq = ["dep:groq-api-rs", "dep:duration-string"]
openai = ["dep:openai_api_rust"]
fireworks = ["dep:openai_api_rust"]
hf = ["dep:openai_api_rust"]

[profile.release]
lto = true # Enable link-time optimization
Expand Down
10 changes: 9 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ While Nerve was inspired by other projects such as Autogen and Rigging, its main

## LLM Support

Nerve features integrations for any model accessible via the [ollama](https://github.com/ollama/ollama), [groq](https://groq.com), [OpenAI](https://openai.com/index/openai-api/) and [Fireworks](https://fireworks.ai/) APIs.
Nerve features integrations for any model accessible via the [ollama](https://github.com/ollama/ollama), [groq](https://groq.com), [OpenAI](https://openai.com/index/openai-api/), [Fireworks](https://fireworks.ai/) and [Hugging Face](https://huggingface.co/blog/tgi-messages-api#using-inference-endpoints-with-openai-client-libraries) APIs.

**The tool will automatically detect if the selected model natively supports function calling. If not, it will provide a compatibility layer that empowers older models to perform function calling anyway.**

Expand Down Expand Up @@ -64,6 +64,14 @@ For **Fireworks**:
LLM_FIREWORKS_KEY=your-api-key nerve -G "fireworks://llama-v3-70b-instruct" ...
```

For **Huggingface**:

Refer to [this document](https://huggingface.co/blog/tgi-messages-api#using-inference-endpoints-with-openai-client-libraries) for how to configure a custom Huggingface endpoint.

```sh
HF_API_TOKEN=your-api-key nerve -G "hf://[email protected]" ...
```

## Example

Let's take a look at the `examples/ssh_agent` example tasklet (a "tasklet" is a YAML file describing a task and the instructions):
Expand Down
40 changes: 40 additions & 0 deletions src/agent/generator/huggingface.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
use anyhow::Result;
use async_trait::async_trait;

use crate::agent::{state::SharedState, Invocation};

use super::{openai::OpenAIClient, Client, Options};

/// Generator backed by the Hugging Face "Messages API", an OpenAI-compatible
/// chat endpoint; all requests are delegated to the wrapped [`OpenAIClient`].
pub struct HuggingfaceMessageClient {
    // OpenAI-compatible client pointed at the custom Hugging Face endpoint.
    client: OpenAIClient,
}

#[async_trait]
impl Client for HuggingfaceMessageClient {
    /// Creates a client for an OpenAI-compatible Hugging Face endpoint.
    ///
    /// `url` is the endpoint host without a scheme; the port and context-window
    /// arguments are ignored since the hosted endpoint manages both itself.
    /// The API token is read from the `HF_API_TOKEN` environment variable by
    /// the underlying OpenAI client.
    fn new(url: &str, _: u16, model_name: &str, _: u32) -> Result<Self>
    where
        Self: Sized,
    {
        // The messages API is served over HTTPS under the /v1/ prefix,
        // mirroring the OpenAI REST layout.
        let message_api = format!("https://{}/v1/", url);
        let client = OpenAIClient::custom(model_name, "HF_API_TOKEN", &message_api)?;

        log::debug!("using huggingface message api @ {}", message_api);

        Ok(Self { client })
    }

    /// Delegates the chat completion to the wrapped OpenAI-compatible client.
    async fn chat(
        &self,
        state: SharedState,
        options: &Options,
    ) -> Result<(String, Vec<Invocation>)> {
        self.client.chat(state, options).await
    }
}

#[async_trait]
impl mini_rag::Embedder for HuggingfaceMessageClient {
    // Delegate embedding generation to the underlying OpenAI-compatible client.
    async fn embed(&self, text: &str) -> Result<mini_rag::Embeddings> {
        self.client.embed(text).await
    }
}
8 changes: 8 additions & 0 deletions src/agent/generator/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ use super::{state::SharedState, Invocation};
mod fireworks;
#[cfg(feature = "groq")]
mod groq;
#[cfg(feature = "hf")]
mod huggingface;
#[cfg(feature = "ollama")]
mod ollama;
#[cfg(feature = "openai")]
Expand Down Expand Up @@ -153,6 +155,12 @@ macro_rules! factory_body {
$model_name,
$context_window,
)?)),
"hf" => Ok(Box::new(huggingface::HuggingfaceMessageClient::new(
$url,
$port,
$model_name,
$context_window,
)?)),
#[cfg(feature = "groq")]
"groq" => Ok(Box::new(groq::GroqClient::new(
$url,
Expand Down
56 changes: 53 additions & 3 deletions src/cli.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use regex::Regex;
lazy_static! {
pub static ref PUBLIC_GENERATOR_PARSER: Regex = Regex::new(r"(?m)^(.+)://(.+)$").unwrap();
pub static ref LOCAL_GENERATOR_PARSER: Regex =
Regex::new(r"(?m)^(.+)://(.+)@(.+):(\d+)$").unwrap();
Regex::new(r"(?m)^(.+)://(.+)@([^:]+):?(\d+)?$").unwrap();
}

#[derive(Default)]
Expand All @@ -21,7 +21,7 @@ pub(crate) struct GeneratorOptions {
}

/// Get things done with LLMs.
#[derive(Parser, Debug)]
#[derive(Parser, Debug, Default)]
#[command(version, about, long_about = None)]
pub(crate) struct Args {
/// Generator string as <type>://<model name>@<host>:<port>
Expand Down Expand Up @@ -96,7 +96,11 @@ impl Args {
.unwrap()
.as_str()
.clone_into(&mut generator.host);
generator.port = caps.get(4).unwrap().as_str().parse::<u16>().unwrap();
// The regex only guarantees digits, not that they fit in a u16: an
// out-of-range port (e.g. ":99999") would make unwrap() panic, so treat
// it like a missing port and fall back to 0.
generator.port = caps
    .get(4)
    .and_then(|port| port.as_str().parse::<u16>().ok())
    .unwrap_or(0);
} else {
let caps = if let Some(caps) = PUBLIC_GENERATOR_PARSER.captures_iter(raw).next() {
caps
Expand Down Expand Up @@ -149,3 +153,49 @@ pub(crate) fn get_user_input(prompt: &str) -> String {
println!();
input.trim().to_string()
}

#[cfg(test)]
mod tests {
    use super::Args;

    // Args derives Default, so struct-update syntax avoids the
    // clippy::field_reassign_with_default anti-pattern of constructing a
    // default value and immediately overwriting a field.

    #[test]
    fn test_wont_parse_invalid_generator() {
        // A string without the <type>://<model> shape must be rejected.
        let args = Args {
            generator: "not a valid generator".to_string(),
            ..Default::default()
        };
        assert!(args.to_generator_options().is_err());
    }

    #[test]
    fn test_parse_local_generator_full() {
        // Full local form: <type>://<model>@<host>:<port>.
        let args = Args {
            generator: "ollama://llama3@localhost:11434".to_string(),
            ..Default::default()
        };
        let ret = args.to_generator_options().unwrap();
        assert_eq!(ret.type_name, "ollama");
        assert_eq!(ret.model_name, "llama3");
        assert_eq!(ret.host, "localhost");
        assert_eq!(ret.port, 11434);
    }

    #[test]
    fn test_parse_local_generator_without_port() {
        // The port is optional; when omitted it defaults to 0.
        let args = Args {
            generator: "ollama://llama3@localhost".to_string(),
            ..Default::default()
        };
        let ret = args.to_generator_options().unwrap();
        assert_eq!(ret.type_name, "ollama");
        assert_eq!(ret.model_name, "llama3");
        assert_eq!(ret.host, "localhost");
        assert_eq!(ret.port, 0);
    }

    #[test]
    fn test_parse_public_generator() {
        // Public form carries no host/port component at all.
        let args = Args {
            generator: "groq://llama3".to_string(),
            ..Default::default()
        };
        let ret = args.to_generator_options().unwrap();
        assert_eq!(ret.type_name, "groq");
        assert_eq!(ret.model_name, "llama3");
        assert_eq!(ret.host, "");
        assert_eq!(ret.port, 0);
    }
}

0 comments on commit a95e546

Please sign in to comment.