Skip to content

Commit

Permalink
new: added rayon parallelization for vector search
Browse files Browse the repository at this point in the history
  • Loading branch information
evilsocket committed Jun 26, 2024
1 parent 501d09b commit e03512f
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 7 deletions.
46 changes: 46 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 3 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,16 +39,17 @@ simple-home-dir = "0.3.5"
tokio = "1.38.0"
xml-rs = "0.8.20"
duration-string = { version = "0.4.0", optional = true }
rayon = { version = "1.10.0", optional = true }
glob = "0.3.1"

[features]
default = ["ollama", "groq", "openai", "fireworks"]
default = ["ollama", "groq", "openai", "fireworks", "rayon"]

ollama = ["dep:ollama-rs"]
groq = ["dep:groq-api-rs", "dep:duration-string"]
openai = ["dep:openai_api_rust"]
fireworks = ["dep:openai_api_rust"]

rayon = ["dep:rayon"]

[profile.release]
lto = true # Enable link-time optimization
Expand Down
25 changes: 20 additions & 5 deletions src/agent/rag/naive.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
use std::{collections::HashMap, time::Instant};

#[cfg(feature = "rayon")]
use rayon::prelude::*;

use anyhow::Result;
use async_trait::async_trait;
use colored::Colorize;
Expand Down Expand Up @@ -89,13 +92,25 @@ impl VectorStore for NaiveVectorStore {
println!("[{}] {} (top {})", "rag".bold(), query, top_k);

let query_vector = self.embedder.embeddings(query).await?;
let mut distances = vec![];
let mut results = vec![];

// TODO: parallelize?
for (doc_name, doc_embedding) in &self.embeddings {
distances.push((doc_name, metrics::cosine(&query_vector, doc_embedding)));
}
#[cfg(feature = "rayon")]
let mut distances: Vec<(&String, f64)> = self
.embeddings
.par_iter()
.map(|(doc_name, doc_embedding)| {
(doc_name, metrics::cosine(&query_vector, doc_embedding))
})
.collect();

#[cfg(not(feature = "rayon"))]
let mut distances = {
let mut distances = vec![];
for (doc_name, doc_embedding) in &self.embeddings {
distances.push((doc_name, metrics::cosine(&query_vector, doc_embedding)));
}
distances
};

distances.sort_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap());

Expand Down

0 comments on commit e03512f

Please sign in to comment.