Skip to content

Commit

Permalink
new: added unit tests for http namespace
Browse files Browse the repository at this point in the history
  • Loading branch information
evilsocket committed Jul 8, 2024
1 parent 02121cf commit 35c63c7
Show file tree
Hide file tree
Showing 2 changed files with 221 additions and 15 deletions.
231 changes: 216 additions & 15 deletions src/agent/namespaces/http/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ use crate::agent::state::SharedState;

use super::{Action, Namespace, StorageDescriptor};

const DEFAULT_HTTP_SCHEMA: &str = "https";

#[derive(Debug, Default, Clone)]
struct ClearHeaders {}

Expand Down Expand Up @@ -97,7 +99,7 @@ impl Request {

// add schema if not present
if !http_target.contains("://") {
http_target = format!("http://{http_target}");
http_target = format!("{DEFAULT_HTTP_SCHEMA}://{http_target}");
}

Url::parse(&http_target)
Expand Down Expand Up @@ -148,6 +150,23 @@ impl Request {

Ok((reason.to_string(), resp))
}

fn create_request(method: &str, target_url: Url) -> Result<reqwest::RequestBuilder> {
let method = reqwest::Method::from_str(method)?;
let mut request = reqwest::Client::new().request(method.clone(), target_url.clone());
let query_str = target_url.query().unwrap_or("").to_string();

// if there're parameters and we're not in GET, set them as the body
if !query_str.is_empty() && !matches!(method, reqwest::Method::GET) {
request = request.header(
reqwest::header::CONTENT_TYPE,
"application/x-www-form-urlencoded",
);
request = request.body(query_str);
}

Ok(request)
}
}

#[async_trait]
Expand Down Expand Up @@ -188,33 +207,23 @@ impl Action for Request {
) -> Result<Option<String>> {
// create a parsed Url from the attributes, payload and HTTP_TARGET variable
let attrs = attrs.unwrap();
let method = reqwest::Method::from_str(attrs.get("method").unwrap())?;
let method = attrs.get("method").unwrap();
let target_url = Self::create_target_url_from(&state, payload.clone()).await?;
let query_str = target_url.query().unwrap_or("").to_string();
let target_url_str = target_url.to_string();
let mut request = Self::create_request(method, target_url)?;

// TODO: handle cookie/session persistency

let mut request = reqwest::Client::new().request(method.clone(), target_url.clone());

// add defined headers
for (key, value) in state.lock().await.get_storage("http-headers")?.iter() {
request = request.header(key, &value.data);
}

// if there're parameters and we're not in GET, set them as the body
if !query_str.is_empty() && !matches!(method, reqwest::Method::GET) {
request = request.header(
reqwest::header::CONTENT_TYPE,
"application/x-www-form-urlencoded",
);
request = request.body(query_str);
}

log::info!(
"{}.{} {} ...",
"http".bold(),
method.to_string().yellow(),
target_url.to_string(),
target_url_str,
);

// perform the request
Expand Down Expand Up @@ -262,3 +271,195 @@ pub(crate) fn get_namespace() -> Namespace {
]),
)
}

#[cfg(test)]
mod tests {
use std::sync::Arc;

use crate::agent::state::State;

use super::*;

#[derive(Debug)]
struct TestTask {}

impl crate::agent::task::Task for TestTask {
fn to_system_prompt(&self) -> Result<String> {
Ok("test".to_string())
}

fn to_prompt(&self) -> Result<String> {
Ok("test".to_string())
}

fn get_functions(&self) -> Vec<Namespace> {
vec![]
}
}

struct TestEmbedder {}

#[async_trait]
impl mini_rag::Embedder for TestEmbedder {
async fn embed(&self, _text: &str) -> Result<mini_rag::Embeddings> {
todo!()
}
}

#[allow(unused_variables)]
async fn create_test_state(vars: Vec<(String, String)>) -> Result<SharedState> {
let (tx, _rx) = crate::agent::events::create_channel();

let task = Box::new(TestTask {});
let embedder = Box::new(TestEmbedder {});

let mut state = State::new(tx, task, embedder, 10).await?;

for (name, value) in vars {
state.set_variable(name, value);
}

Ok(Arc::new(tokio::sync::Mutex::new(state)))
}

#[tokio::test]
async fn test_parse_no_target() {
let state = create_test_state(vec![]).await.unwrap();
let payload = Some("/".to_string());
let target_url = Request::create_target_url_from(&state, payload.clone()).await;

assert!(target_url.is_err());
}

#[tokio::test]
async fn test_parse_simple_get_without_schema() {
let state = create_test_state(vec![(
"HTTP_TARGET".to_string(),
"www.example.com".to_string(),
)])
.await
.unwrap();

let payload = Some("/".to_string());
let target_url = Request::create_target_url_from(&state, payload.clone())
.await
.unwrap();

assert_eq!(
target_url.to_string(),
format!("{DEFAULT_HTTP_SCHEMA}://www.example.com/")
);
}

#[tokio::test]
async fn test_parse_simple_get_with_schema() {
let state = create_test_state(vec![(
"HTTP_TARGET".to_string(),
"ftp://www.example.com".to_string(),
)])
.await
.unwrap();

let payload = Some("/".to_string());
let target_url = Request::create_target_url_from(&state, payload.clone())
.await
.unwrap();

assert_eq!(target_url.to_string(), format!("ftp://www.example.com/"));
}

#[tokio::test]
async fn test_parse_simple_get_with_schema_and_port() {
let state = create_test_state(vec![(
"HTTP_TARGET".to_string(),
"ftp://www.example.com:1012".to_string(),
)])
.await
.unwrap();

let payload = Some("/".to_string());
let target_url = Request::create_target_url_from(&state, payload.clone())
.await
.unwrap();

assert_eq!(
target_url.to_string(),
format!("ftp://www.example.com:1012/")
);
}

#[tokio::test]
async fn test_parse_query_string() {
let state = create_test_state(vec![(
"HTTP_TARGET".to_string(),
"www.example.com".to_string(),
)])
.await
.unwrap();

let payload = Some("/index.php?id=1&name=foo".to_string());
let target_url = Request::create_target_url_from(&state, payload.clone())
.await
.unwrap();

assert_eq!(
target_url.to_string(),
format!("{DEFAULT_HTTP_SCHEMA}://www.example.com/index.php?id=1&name=foo")
);
}

#[tokio::test]
async fn test_parse_query_string_is_escaped() {
let state = create_test_state(vec![(
"HTTP_TARGET".to_string(),
"www.example.com".to_string(),
)])
.await
.unwrap();

let payload = Some("/index.php?id=1&name=foo' or ''='".to_string());
let target_url = Request::create_target_url_from(&state, payload.clone())
.await
.unwrap();

assert_eq!(
target_url.to_string(),
format!("{DEFAULT_HTTP_SCHEMA}://www.example.com/index.php?id=1&name=foo%27%20or%20%27%27=%27")
);
}
#[tokio::test]
async fn test_parse_body_post() {
let state = create_test_state(vec![(
"HTTP_TARGET".to_string(),
"www.example.com".to_string(),
)])
.await
.unwrap();

let method = "POST";
let payload = Some("/login.php?user=admin&pass=' OR ''='".to_string());
let target_url = Request::create_target_url_from(&state, payload.clone())
.await
.unwrap();
let expected_body_string = "user=admin&pass=%27%20OR%20%27%27=%27".to_string();
let expected_target_url_string = format!(
"{DEFAULT_HTTP_SCHEMA}://www.example.com/login.php?{}",
expected_body_string
);

assert_eq!(target_url.to_string(), expected_target_url_string);

let request = Request::create_request(method, target_url)
.unwrap()
.build()
.unwrap();

assert_eq!(request.method().to_string(), method.to_string());
assert_eq!(request.url().to_string(), expected_target_url_string);
assert!(request.body().is_some());
assert_eq!(
request.body().unwrap().as_bytes(),
Some(expected_body_string.as_bytes())
);
}
}
5 changes: 5 additions & 0 deletions src/agent/state/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,11 @@ impl State {
self.variables.get(name)
}

#[allow(dead_code)]
pub fn set_variable(&mut self, name: String, value: String) {
self.variables.insert(name, value);
}

pub fn get_storages(&self) -> Vec<&Storage> {
self.storages.values().collect()
}
Expand Down

0 comments on commit 35c63c7

Please sign in to comment.