Commit

working at #9

evilsocket committed Jul 18, 2024
1 parent 00d6633 commit 16085e2
Showing 9 changed files with 116 additions and 162 deletions.
57 changes: 35 additions & 22 deletions Cargo.lock

Some generated files are not rendered by default.

8 changes: 2 additions & 6 deletions cake-core/src/cake/master.rs
@@ -26,9 +26,9 @@ impl<G: Generator + Send + Sync + 'static> Master<G> {
} else {
// if running in cli mode, pre add system and user prompts
self.model
- .add_message(Message::system(self.ctx.args.prompt.clone()))?;
+ .add_message(Message::system(self.ctx.args.system_prompt.clone()))?;
self.model
- .add_message(Message::user(self.ctx.args.system_prompt.clone()))?;
+ .add_message(Message::user(self.ctx.args.prompt.clone()))?;

// just run one generation to stdout
self.generate(|data| {
@@ -80,10 +80,6 @@ impl<G: Generator + Send + Sync + 'static> Master<G> {
}
}

- if let Some(rest) = self.model.last().await? {
- stream(&rest);
- }

// signal end of stream
stream("");

2 changes: 1 addition & 1 deletion cake-core/src/lib.rs
@@ -35,7 +35,7 @@ pub struct Args {
#[arg(long, default_value = "./cake-data/topology.yml")]
pub topology: String,
/// The initial prompt.
#[arg(long, default_value = "Why is the sky blue?")]
#[arg(long, default_value = "The sky is blue because ")]
pub prompt: String,
/// The system prompt.
#[arg(long, default_value = "You are a helpful AI assistant.")]
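
For reference, a minimal sketch of how the clap-derived Args struct reads after this change. Only the fields visible in this hunk plus the system_prompt field referenced from master.rs above are shown; the derive attribute, the doc comment on topology, and the main() wrapper are assumptions, and any other fields of Args are omitted:

    use clap::Parser;

    #[derive(Parser, Debug)]
    pub struct Args {
        /// Path to the topology file (doc comment assumed).
        #[arg(long, default_value = "./cake-data/topology.yml")]
        pub topology: String,
        /// The initial prompt.
        #[arg(long, default_value = "The sky is blue because ")]
        pub prompt: String,
        /// The system prompt.
        #[arg(long, default_value = "You are a helpful AI assistant.")]
        pub system_prompt: String,
    }

    fn main() {
        // Each #[arg(long, default_value = ...)] becomes a --flag with that default,
        // e.g. passing --prompt "Why is the sky blue?" overrides only the prompt field.
        let args = Args::parse();
        println!("{args:?}");
    }
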
47 changes: 47 additions & 0 deletions cake-core/src/models/llama3/history.rs
@@ -0,0 +1,47 @@
use crate::models::chat::Message;

/// Chat history.
pub struct History(Vec<Message>);

// Adapted from https://github.com/meta-llama/llama3/blob/main/llama/tokenizer.py#L202
impl History {
fn encode_header(message: &Message) -> String {
format!("<|start_header_id|>{}<|end_header_id|>\n\n", message.role)
}

fn encode_message(message: &Message) -> String {
Self::encode_header(message) + message.content.trim() + "<|eot_id|>"
}

/// Create a new instance of this object.
pub fn new() -> Self {
Self(vec![])
}

/// Encode the dialog to llama3 prompt format.
pub fn encode_dialog_to_prompt(&self) -> String {
let mut encoded = "<|begin_of_text|>".to_string();

for message in self.iter() {
encoded += &Self::encode_message(message);
}

// Add the start of an assistant message for the model to complete.
encoded += &Self::encode_header(&Message::assistant("".to_string()));

encoded
}
}

impl std::ops::Deref for History {
type Target = Vec<Message>;
fn deref(&self) -> &Vec<Message> {
&self.0
}
}

impl std::ops::DerefMut for History {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}
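
A minimal usage sketch of the new History type, based only on the code above and the Message::system / Message::user constructors used elsewhere in this commit; the module path and the helper function are hypothetical:

    use crate::models::{chat::Message, llama3::History};

    // Hypothetical helper, not part of this commit: build a llama3 prompt
    // from a system prompt and a user prompt.
    fn build_prompt(system_prompt: &str, user_prompt: &str) -> String {
        // History derefs to Vec<Message>, so Vec methods such as push are available.
        let mut history = History::new();
        history.push(Message::system(system_prompt.to_string()));
        history.push(Message::user(user_prompt.to_string()));

        // Produces one contiguous string of the form:
        //   <|begin_of_text|>
        //   <|start_header_id|>system<|end_header_id|>\n\n{system_prompt}<|eot_id|>
        //   <|start_header_id|>user<|end_header_id|>\n\n{user_prompt}<|eot_id|>
        //   <|start_header_id|>assistant<|end_header_id|>\n\n
        // i.e. the dialog plus the opening of an assistant turn for the model to complete.
        history.encode_dialog_to_prompt()
    }
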
63 changes: 28 additions & 35 deletions cake-core/src/models/llama3/llama.rs
@@ -8,27 +8,25 @@ use tokenizers::Tokenizer;
use crate::{
cake::{Context, Forwarder},
models::{chat::Message, Generator, Token},
- utils::{self, TokenOutputStream},
};

- use super::transformer::Transformer;
+ use super::{transformer::Transformer, History};

- /// End of stream token.
- const EOS_TOKEN: &str = "</s>";
+ /// Default end of stream token if not found in configuration.
+ const DEFAULT_EOS_TOKEN: &str = "</s>";

/// Load the tokenizer and return the first tokens from the prompt in context.
- fn load_tokenizer(ctx: &Context) -> Result<(TokenOutputStream, Option<u32>)> {
+ fn load_tokenizer(ctx: &Context) -> Result<(Tokenizer, Option<u32>)> {
let tokenizer_filename = ctx.data_path.join("tokenizer.json");

log::info!("loading tokenizer from {}", tokenizer_filename.display());

let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(anyhow::Error::msg)?;

let eos_token_id = ctx
.config
.eos_token_id
- .or_else(|| tokenizer.token_to_id(EOS_TOKEN));

- let tokenizer = utils::TokenOutputStream::new(tokenizer);
+ .or_else(|| tokenizer.token_to_id(DEFAULT_EOS_TOKEN));

Ok((tokenizer, eos_token_id))
}
@@ -53,7 +51,7 @@ fn create_logits_processor(ctx: &Context) -> LogitsProcessor {
pub struct LLama {
ctx: Context,

- tokenizer: TokenOutputStream,
+ tokenizer: Tokenizer,
embedding: Embedding,
eos_token_id: Option<u32>,
index_pos: usize,
@@ -66,7 +64,7 @@ pub struct LLama {

logits_processor: LogitsProcessor,

- history: Vec<Message>,
+ history: History,
tokens: Vec<u32>,
}

@@ -138,13 +136,6 @@ impl LLama {
.to_dtype(DType::F32)
.map_err(|e| anyhow!("error converting logits: {e}"))
}

- fn raw_message(&self, message: &Message) -> String {
- format!(
- "<|start_header_id|>{}<|end_header_id|>\n\n{}<|eot_id|>\n",
- message.role, &message.content
- )
- }
}

#[async_trait]
@@ -204,7 +195,7 @@ impl Generator for LLama {

let (tokenizer, eos_token_id) = load_tokenizer(&ctx)?;
let tokens = vec![];
- let history = vec![];
+ let history = History::new();

let logits_processor = create_logits_processor(&ctx);
let index_pos = 0;
@@ -238,10 +229,10 @@ impl Generator for LLama {
Ok(())
}

+ /// Reset the chat pipeline state.
fn reset(&mut self) -> Result<()> {
self.tokens.clear();
self.history.clear();
- self.tokenizer.clear();
self.ctx.cache.clear();
self.index_pos = 0;
self.generated = 0;
@@ -254,28 +245,29 @@ impl Generator for LLama {

// Prefill tokens with chat history the first time.
if self.generated == 0 {
+ // make sure we start clean
+ self.tokens.clear();
+ self.ctx.cache.clear();
+ self.index_pos = 0;
+ self.generated = 0;

log::debug!("generating history tokens ...");

- // generate raw from history
- let mut raw = "<|begin_of_text|>".to_string();

- for message in &self.history {
- raw += &self.raw_message(message);
- }

- raw += "<|start_header_id|>assistant<|end_header_id|>";
+ let dialog = self.history.encode_dialog_to_prompt();

log::debug!("{}", &raw);
log::debug!("dialog={}", &dialog);

- // tokenize raw
self.tokens = self
.tokenizer
- .tokenizer()
- .encode(raw, true)
+ .encode(dialog, false) // do not add special tokens as we already added them
.map_err(anyhow::Error::msg)?
.get_ids()
.to_vec();

log::debug!("encoded={:?}", &self.tokens);

log::debug!("history tokens: {}", self.tokens.len());
}

@@ -326,16 +318,17 @@ impl Generator for LLama {

Ok(Token {
id: next_token,
- text: self.tokenizer.next_token(next_token)?,
+ text: match self.tokenizer.decode(&[next_token], false) {
+ Ok(s) => Some(s),
+ Err(e) => {
+ log::error!("could not decode token {next_token}: {e}");
+ None
+ }
+ },
is_end_of_stream: Some(next_token) == self.eos_token_id,
})
}

- /// Return any resitual token if necessary or None if not.
- async fn last(&mut self) -> Result<Option<String>> {
- self.tokenizer.decode_rest().map_err(anyhow::Error::msg)
- }

/// Return the number of generated tokens so far.
fn generated_tokens(&self) -> usize {
self.generated
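
The net effect of the tokenizer changes in this file: LLama now holds a plain tokenizers::Tokenizer instead of the crate's utils::TokenOutputStream wrapper, the dialog produced by History is encoded without re-adding special tokens (they are already in the prompt string), each generated id is decoded on its own, and the buffered-residual last() method is removed. A standalone sketch of that usage with the tokenizers crate; the file path and prompt string are placeholders:

    use tokenizers::Tokenizer;

    fn main() -> anyhow::Result<()> {
        // Placeholder path; the crate loads <data_path>/tokenizer.json in load_tokenizer().
        let tokenizer = Tokenizer::from_file("tokenizer.json").map_err(anyhow::Error::msg)?;

        // Special tokens such as <|begin_of_text|> are already embedded in the
        // dialog string, so add_special_tokens is false here.
        let dialog = "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nhello<|eot_id|>";
        let ids = tokenizer
            .encode(dialog, false)
            .map_err(anyhow::Error::msg)?
            .get_ids()
            .to_vec();

        // Decode one id at a time, tolerating failures instead of aborting,
        // mirroring the error handling introduced by this commit.
        for id in ids {
            match tokenizer.decode(&[id], false) {
                Ok(text) => print!("{text}"),
                Err(e) => eprintln!("could not decode token {id}: {e}"),
            }
        }
        println!();
        Ok(())
    }
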