<p><em>Ian Mason: Artificial intelligence and machine learning research. Jekyll feed <a href="https://ianxmason.github.io/feed.xml" target="_blank">https://ianxmason.github.io/feed.xml</a>, updated 2023-12-18T16:05:30+00:00.</em></p>
<h1>Building a Minimal RAG Model</h1>
<p><em>Ian Mason, 2023-10-21. <a href="https://ianxmason.github.io/posts/RAG" target="_blank">https://ianxmason.github.io/posts/RAG</a></em></p>
<p>Large language models (LLMs) like ChatGPT are very good at generating coherent text on a wide range of topics. Often, however, we want to generate text for a very specific use case. For example, imagine we want a model that can answer factual questions about historical financial data. When we ask ChatGPT the question <em>“What was the inflation rate in Indonesia in 1986?”</em>, it states that it doesn’t have enough information to provide a good answer. Other LLMs might give a reasonable-looking but factually inaccurate answer, as LLMs can be <a href="https://arxiv.org/abs/2302.03494" target="_blank">prone to</a> <a href="https://arxiv.org/abs/2302.04023" target="_blank">hallucinations</a>.</p>
<p><img src="/assets/images/RAG/chatgpt.png" alt="Indonesia Chat" /></p>
<p>If we want to build a model that can answer questions like this, we have a couple of options. We could fine-tune an LLM on a dataset of questions and answers, but this requires a lot of data, can be expensive, and may cost the LLM some generality in the text it can generate. Alternatively, we could augment the model with external <a href="https://arxiv.org/abs/1909.00109" target="_blank">tools</a> <a href="https://arxiv.org/pdf/2205.12255.pdf" target="_blank">or</a> <a href="https://arxiv.org/abs/2303.17580" target="_blank">resources</a>.</p>
<p><strong>RAG</strong> (retrieval augmented generation) models aim to augment language models by providing them with additional context with which to respond to a user query. Although the <a href="https://arxiv.org/abs/2005.11401" target="_blank">original design</a> included a training loop, RAG models now tend to stitch together pre-trained components with a vector database. The figure below (taken from a more <a href="https://www.anyscale.com/blog/a-comprehensive-guide-for-building-rag-based-llm-applications-part-1" target="_blank">detailed tutorial</a> on building RAG models at scale) shows the main components and structure of a RAG model.</p>
<p><img src="/assets/images/RAG/anyscale_rag.png" alt="Rag Model" /></p>
<ul>
<li><em>Set-Up:</em> A vector database (VectorDB) is created offline by embedding relevant resources/documents using a pre-trained <a href="https://platform.openai.com/docs/guides/embeddings" target="_blank">language embedding model</a> that converts text to vector embeddings.</li>
<li><em>Step One:</em> A user query is received and embedded with the same embedding model used to create the VectorDB.</li>
<li><em>Step Two:</em> The VectorDB is searched for the document(s) whose embedding(s) are most similar to the query embedding, using cosine similarity, L2 distance, or a similar metric.</li>
<li><em>Step Three:</em> The documents retrieved from the VectorDB are added to the user’s query as additional context.</li>
<li><em>Step Four:</em> The original query augmented with the documents is fed into an LLM.</li>
<li><em>Step Five:</em> The LLM should now generate a more reasonable response using the additional information from an outside source.</li>
</ul>
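<p>As a rough illustration of Step Two, the NumPy sketch below ranks documents by cosine similarity and by L2 distance. The vectors are random stand-ins for real embeddings (an assumption; no embedding model is called), normalised to unit length. For unit vectors $\|a - b\|^2 = 2 - 2\cos(a, b)$, so both metrics should return the same top documents.</p>

```python
import numpy as np


def cosine_similarity(query: np.ndarray, docs: np.ndarray) -> np.ndarray:
    """Cosine similarity between one query vector and each row of docs."""
    return docs @ query / (np.linalg.norm(docs, axis=1) * np.linalg.norm(query))


def l2_distance(query: np.ndarray, docs: np.ndarray) -> np.ndarray:
    """Euclidean (L2) distance between the query and each row of docs."""
    return np.linalg.norm(docs - query, axis=1)


def top_k(scores: np.ndarray, k: int, largest: bool) -> list[int]:
    """Indices of the k best scores (largest similarity or smallest distance)."""
    order = np.argsort(-scores if largest else scores)
    return order[:k].tolist()


rng = np.random.default_rng(0)
# Random stand-ins for document and query embeddings, normalised to unit length.
doc_vecs = rng.normal(size=(100, 16))
doc_vecs /= np.linalg.norm(doc_vecs, axis=1, keepdims=True)
query_vec = rng.normal(size=16)
query_vec /= np.linalg.norm(query_vec)

by_cosine = top_k(cosine_similarity(query_vec, doc_vecs), k=3, largest=True)
by_l2 = top_k(l2_distance(query_vec, doc_vecs), k=3, largest=False)
print(by_cosine, by_l2)  # the same three indices, in the same order
```

<p>If the embeddings were not normalised, the two rankings could differ, which is one reason to use the same metric consistently when building and querying the database.</p>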
<h2 id="spinning-up-a-minimal-example">Spinning Up a Minimal Example</h2>
<p>To see what this looks like in practice, we will walk through building a very simple RAG model from readily available tools. The documents we embed to create the VectorDB are a small number of news documents from the <a href="https://www.nltk.org/book/ch02.html" target="_blank">NLTK Reuters</a> corpus. To build the VectorDB we use <a href="https://www.trychroma.com" target="_blank">ChromaDB</a>, which lets us create a locally stored vector database with a few lines of code. We use the <a href="https://platform.openai.com/docs/introduction" target="_blank">OpenAI API</a> to access powerful models for embedding documents and generating text.</p>
<h3 id="creating-a-vectordb">Creating a VectorDB</h3>
<p>A VectorDB is just a database in which each entry has an associated vector, allowing us to search the database for the entry with the closest vector. To keep costs low, for this example we take the first 100 documents from the NLTK Reuters corpus and keep only those with fewer than 500 words. Returning to our original example on Indonesian inflation, one of the documents in this set is shown below. If we can provide this context to the LLM, it should be able to give a good answer to our question about past inflation rates in Indonesia.</p>
<p><img src="/assets/images/RAG/indonesia_doc.png" alt="Indonesian Inflation" /></p>
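<p>To make the idea of a VectorDB concrete, here is a hypothetical in-memory version: parallel lists of ids, documents and unit-length vectors, queried by cosine similarity. The class name <code class="language-plaintext highlighter-rouge">ToyVectorDB</code> and its hand-picked 3-dimensional vectors are illustrative assumptions, not real embeddings; in the code that follows, ChromaDB plays this role with real embeddings and on-disk storage.</p>

```python
import numpy as np


class ToyVectorDB:
    """A hypothetical in-memory vector store: each entry is (id, text, vector)."""

    def __init__(self) -> None:
        self.ids: list[str] = []
        self.documents: list[str] = []
        self.vectors: list[np.ndarray] = []

    def add(self, doc_id: str, document: str, vector) -> None:
        # Store a unit-length copy so that a dot product equals cosine similarity.
        v = np.asarray(vector, dtype=float)
        self.ids.append(doc_id)
        self.documents.append(document)
        self.vectors.append(v / np.linalg.norm(v))

    def query(self, vector, n_results: int = 3) -> list[tuple[str, str]]:
        # Normalise the query, score every entry, and return the best matches.
        q = np.asarray(vector, dtype=float)
        q = q / np.linalg.norm(q)
        scores = np.stack(self.vectors) @ q
        best = np.argsort(-scores)[:n_results]
        return [(self.ids[i], self.documents[i]) for i in best]


db = ToyVectorDB()
db.add("0", "Indonesian inflation figures", [0.9, 0.1, 0.0])
db.add("1", "Cocoa exports", [0.0, 1.0, 0.2])
db.add("2", "Grain shipments", [0.1, 0.2, 1.0])
print(db.query([1.0, 0.0, 0.1], n_results=1))  # [('0', 'Indonesian inflation figures')]
```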
<p>First, we set up a method to embed documents using the OpenAI <code class="language-plaintext highlighter-rouge">text-embedding-ada-002</code> model.</p>
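<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code># utils.py (imported by the scripts below)
import os
from chromadb.utils import embedding_functions

def get_embedding_function():
    # Reads the API key from the OPENAI_API_KEY environment variable; the original
    # code calls a get_openai_key() helper whose definition is not shown here.
    openai_ef = embedding_functions.OpenAIEmbeddingFunction(
        api_key=os.environ["OPENAI_API_KEY"],
        model_name="text-embedding-ada-002"
    )
    return openai_ef
</code></pre></div></div>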
<p>With this embedding function, the code below creates a locally stored VectorDB that holds the raw text of each Reuters document, an id, its NLTK file id as metadata, and the vector created by the embedding model.</p>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import chromadb
import nltk
from nltk.corpus import reuters
from utils import get_embedding_function

nltk.download('reuters')

# Keep the first 100 Reuters documents, dropping any with 500 or more words.
reuters_subset = reuters.fileids()[0:100]
reuters_subset = [fid for fid in reuters_subset if len(reuters.words(fid)) < 500]

client = chromadb.PersistentClient(path="chromadb/test_db")
# create_collection raises an error if the collection already exists (e.g. on a re-run).
collection = client.create_collection(name="reuters_collection", embedding_function=get_embedding_function())

# Each document is embedded by the collection's embedding function as it is added.
for i, file_id in enumerate(reuters_subset):
    collection.add(
        documents=[reuters.raw(file_id)],
        metadatas=[{"nltk_file_id": file_id}],
        ids=[str(i)]
    )

print(collection.peek())  # To see the first documents in the collection
</code></pre></div></div>
<h3 id="building-a-rag-model">Building a RAG model</h3>
<p>Now that we have constructed a VectorDB of relevant documents, when we receive a user query we search the VectorDB for the 3 most similar documents and add them to the user query as context for generating a response. In the code below, <code class="language-plaintext highlighter-rouge">get_rag_context</code> embeds the user query and finds the 3 most similar documents to use as additional context. The method <code class="language-plaintext highlighter-rouge">rag_response</code> gets a response from the LLM (<code class="language-plaintext highlighter-rouge">gpt-3.5-turbo</code>) when these documents are provided alongside the query. For comparison, the method <code class="language-plaintext highlighter-rouge">response</code> gets a response using only the user query, with no additional information.</p>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code># Written for the pre-1.0 openai Python package (openai<1.0), which provides
# openai.ChatCompletion and reads the API key from the OPENAI_API_KEY environment variable.
import openai
import chromadb
from utils import get_embedding_function

def response(query):
    # Baseline: answer the query with no retrieved context.
    completion = openai.ChatCompletion.create(
        model="gpt-3.5-turbo",
        messages=[
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": f"{query}"},
        ]
    )
    return completion['choices'][0]['message']['content']

def rag_response(query, context):
    # RAG: pass the retrieved documents to the LLM alongside the query.
    completion = openai.ChatCompletion.create(
        model="gpt-3.5-turbo",
        messages=[
            {"role": "system", "content": "You are a helpful assistant. Please answer the query using the context provided."},
            {"role": "user", "content": f"query: {query}. context: {context}"},
        ]
    )
    return completion['choices'][0]['message']['content']

def get_rag_context(query, client, num_docs=3):
    # Embed the query with the same embedding function used to build the collection
    # and return the num_docs most similar documents, with newlines flattened.
    collection = client.get_collection(name="reuters_collection", embedding_function=get_embedding_function())
    results = collection.query(
        query_texts=[query],
        n_results=num_docs
    )
    contexts = [doc.replace("\n", " ") for doc in results['documents'][0]]
    return contexts

def main():
    # This path must point at the VectorDB built earlier; it is relative to the
    # directory the script is run from.
    client = chromadb.PersistentClient(path="../chromadb/test_db")
    query = "What was the inflation rate in Indonesia in 1986?"
    contexts = get_rag_context(query, client)
    default_response = response(query)
    ragged_response = rag_response(query, ";".join(contexts))
    print(f"Query: {query}")
    print(f"Default response: {default_response}")
    print(f"RAG response: {ragged_response}")

if __name__ == "__main__":
    main()
</code></pre></div></div>
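<p>Before spending API calls, it can help to inspect exactly what <code class="language-plaintext highlighter-rouge">rag_response</code> will send. The hypothetical helper below rebuilds the same message list (same system prompt, same newline stripping and <code class="language-plaintext highlighter-rouge">";"</code> join) without calling OpenAI; the two context strings are placeholders, not real Reuters documents.</p>

```python
def build_rag_messages(query: str, contexts: list[str]) -> list[dict]:
    """Hypothetical helper: the chat messages rag_response sends, built without an API call."""
    # Same newline stripping as get_rag_context and the same ";" join as main().
    context = ";".join(doc.replace("\n", " ") for doc in contexts)
    return [
        {"role": "system",
         "content": "You are a helpful assistant. Please answer the query using the context provided."},
        {"role": "user", "content": f"query: {query}. context: {context}"},
    ]


messages = build_rag_messages(
    "What was the inflation rate in Indonesia in 1986?",
    ["First document\ntext", "Second document"],
)
for m in messages:
    print(m["role"], "->", m["content"])
```

<p>Printing the prompt this way also makes it easy to see how much text three full news documents add to each request.</p>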
<h3 id="rag-in-action">RAG in action</h3>
<p>Now we can see what happens when we ask the RAG model our original question about the inflation rate in Indonesia in 1986. With the additional context, the LLM answers the question correctly (8.8%), whereas the default LLM without augmented context is unable to provide a definitive answer.</p>
<p><img src="/assets/images/RAG/rag_response.png" alt="Rag Response" /></p>
<h2 id="summary">Summary</h2>
<p>In around 100 lines of code, we have augmented a large language model to use external resources when answering user queries. We have seen that, given the right additional context, LLMs can better answer specific technical questions.</p>
<p>This simple code is only possible because of recent improvements in the tooling available for building these models. We can query large models in a few lines of code with the OpenAI API and build a VectorDB quickly with ChromaDB. Recently, OpenAI have even launched a beta for their <a href="https://platform.openai.com/docs/assistants/overview" target="_blank">Assistants API</a> where you can use retrieval without ever leaving the OpenAI ecosystem.</p>
<p>RAG models are a useful tool for building applications with LLMs right now. However, despite the release of GPT-4, which is bigger and better than the GPT-3.5 model used here, hallucinations and the inability to answer factual questions remain a problem. I am less sure whether RAG-style approaches will lead to more general AI systems; currently I lean more towards scaling self-prediction in end-to-end systems.</p>
<hr />
<h3 id="code">Code</h3>
<p>A repository containing the minimal RAG model discussed in this post can be found <a href="https://github.com/ianxmason/minimal-rag-model" target="_blank">here</a>.</p>
<p><em>Ian Mason. A simple demo for retrieval augmented generation in 100 lines of code.</em></p>
<h1>Periodic Autoencoder - Explanation and Addendum</h1>
<p><em>Ian Mason, 2022-04-09. <a href="https://ianxmason.github.io/posts/PAE" target="_blank">https://ianxmason.github.io/posts/PAE</a></em></p>
<p>In <a href="https://ianxmason.github.io/papers/deep_phase.pdf" target="_blank">DeepPhase</a> we use a Periodic Autoencoder (PAE) to learn periodic features from data. In particular, we aim to extract a small number of phase values that capture well the (non-linear) periodicity of higher-dimensional time series data. If you only want to implement the Periodic Autoencoder for your own use case, see the supplementary material in the <a href="https://dl.acm.org/doi/10.1145/3528223.3530178" target="_blank">ACM Digital Library</a>. What follows is a brief explanation of how we can extract phase features from data.</p>
<p>Given some number of temporal signals, which are assumed to have some joint periodicity, the Periodic Autoencoder learns a latent space, $\mathbf{L}$, with a few (say 5) latent signals of the same length as the original signals. These latent signals might look something like the signals in the left plot below. For each one of these signals we then aim to extract a good phase offset that captures its current location as part of a larger cycle.</p>
<p>However, any one of these signals (say the blue curve in the right plot) may not have a single obvious frequency, amplitude, or phase offset. So the question becomes: what is a good way to calculate a phase value for such a curve?</p>
<p><img src="/assets/images/PAE/latent_signals.png" alt="Latent Signals" /></p>
<p>The approach taken with the Periodic Autoencoder is to approximate each latent signal with a sinusoidal function $\Gamma(x) = A \sin (2 \pi (Fx - S)) + B$, parameterized by $A$, $F$, $B$ and $S$, and then to use $S$ as the phase offset. The plot below shows the same blue curve from above along with the function $\Gamma(x) = 0.01 \sin(4 \pi x)$. Our aim is to set the parameters of $\Gamma$ so that the orange curve more closely approximates the blue curve. After calculating these sinusoidal functions for all latent signals, the PAE then aims to reconstruct the original data from them. This places a strong inductive bias on the latent space: it assumes that a few periodic functions are sufficient for reconstruction.</p>
<p><img src="/assets/images/PAE/chosen_signal.png" alt="Chosen Signal" /></p>
<p>To find the parameters for $\Gamma$, suppose this 1D signal (blue curve) contains $N$ points over a time window of $T$ seconds. Applying a real discrete Fourier transform [1] gives $K+1$ Fourier coefficients $\mathbf{C}=[c_0, c_1, \dots, c_K]$, where $K = \left\lfloor\frac{N}{2}\right\rfloor$. These coefficients correspond to frequency bins centered at $\mathbf{f}=[0, 1/T, 2/T, \dots, K/T]$ Hz. (The real DFT does not calculate the negative frequency terms, as they are redundant for real-valued signals; it returns only a single-sided spectrum.)</p>
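<p>The bin layout can be checked with a short sketch. It uses NumPy's <code class="language-plaintext highlighter-rouge">np.fft.rfft</code> and <code class="language-plaintext highlighter-rouge">np.fft.rfftfreq</code> in place of the PyTorch function in [1] (both use an unnormalised forward transform by default); the values of $N$, $T$ and the test signal are arbitrary choices.</p>

```python
import numpy as np

N, T = 61, 2.0                      # 61 samples spanning a 2-second window (arbitrary)
x = np.arange(N) * (T / N)          # evenly spaced sample times, spacing T/N
y = np.sin(2 * np.pi * 1.5 * x)     # any real-valued signal works here

C = np.fft.rfft(y)                  # single-sided spectrum: c_0, ..., c_K
K = N // 2
f = np.fft.rfftfreq(N, d=T / N)     # bin centres in Hz

print(len(C) == K + 1)                       # True
print(np.allclose(f, np.arange(K + 1) / T))  # True: bins are 0, 1/T, ..., K/T
```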
<p>The magnitudes of the coefficients, $\mathbf{m} = [m_0, m_1, \dots, m_K]$ $= [|c_0|, |c_1|, \dots, |c_K|]$, represent the relative presence of each of the frequency bins in the signal. From these we can calculate the power spectrum (the amount of the signal’s power present in each of the frequency bins) as $\mathbf{p} = [p_0, p_1, \dots, p_K] =\frac{2}{N} [\frac{1}{2} m_0^2, m_1^2, \dots, m_K^2]$. Note that every term except the $m_0$ term is doubled, since the real DFT returns the single-sided spectrum and we still wish to account for the power in the double-sided spectrum [2]. Below we show the whole single-sided power spectrum in the left plot and the power spectrum without the zero frequency bin in the right plot.</p>
<p><img src="/assets/images/PAE/power_spectrum.png" alt="Power Spectrum" /></p>
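<p>A NumPy sketch of this power spectrum follows (an illustration, not the PAE implementation). For odd $N$ there is no unpaired Nyquist bin, and by Parseval's theorem the entries of $\mathbf{p}$ sum exactly to $\sum_i y_i^2$, the total energy of the signal; dividing by $N$ later turns this into average power.</p>

```python
import numpy as np


def power_spectrum(y: np.ndarray) -> np.ndarray:
    """p = (2/N) [m_0^2 / 2, m_1^2, ..., m_K^2], as defined in the text."""
    N = len(y)
    m = np.abs(np.fft.rfft(y))
    p = (2.0 / N) * m**2
    p[0] /= 2.0  # the zero-frequency term is not doubled
    # For even N the last (Nyquist) bin also has no mirror image; the formula in
    # the text doubles it anyway, and this sketch reproduces that.
    return p


rng = np.random.default_rng(0)
y = rng.normal(size=101)  # odd N, so there is no Nyquist bin
p = power_spectrum(y)
print(np.isclose(p.sum(), np.sum(y**2)))  # True: Parseval's theorem
```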
<p>$c_0$ is the “zero frequency” Fourier coefficient, which is equivalently the sum of the $N$ samples and is always real. Therefore, dividing by the number of samples gives us the mean offset of the signal which we use as the offset for the sinusoidal approximation $\Gamma$:</p>
\[B = \frac{c_0}{N}.\]
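<p>Numerically, $c_0 / N$ is just the sample mean. A quick check with NumPy on an arbitrary noisy signal:</p>

```python
import numpy as np

rng = np.random.default_rng(1)
y = 0.3 + 0.05 * rng.normal(size=64)  # a noisy signal with an offset around 0.3

c = np.fft.rfft(y)
B = c[0].real / len(y)  # B = c_0 / N

print(np.isclose(c[0].imag, 0.0))  # True: c_0 is real for a real-valued signal
print(np.isclose(B, y.mean()))     # True: c_0 / N is the sample mean
```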
<p>After applying this value of $B$, we see that our approximation is better aligned along the y-axis.</p>
<p><img src="/assets/images/PAE/signal_with_B.png" alt="Approximate Bias" /></p>
<p>A good value for the single frequency $F$ is the mean frequency of the overall signal. We can calculate the mean frequency by taking an average of the frequency components weighted by the power with which they appear in the signal [3]. Note that the zero frequency term is not included as this is already captured by $B$.</p>
\[F = \frac{\sum_{j=1}^K \mathbf{f}_j \mathbf{p}_j}{\sum_{j=1}^K \mathbf{p}_j}.\]
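<p>As a sketch (NumPy, arbitrary test signal), the power-weighted mean frequency of two equal-amplitude components at 2 Hz and 4 Hz should land halfway between them, while the constant offset is ignored because the zero-frequency bin is excluded:</p>

```python
import numpy as np


def mean_frequency(y: np.ndarray, T: float) -> float:
    """Power-weighted mean frequency F, excluding the zero-frequency bin."""
    N = len(y)
    p = (2.0 / N) * np.abs(np.fft.rfft(y)) ** 2  # p_0 is unused, so no halving needed
    f = np.fft.rfftfreq(N, d=T / N)
    return float(np.sum(f[1:] * p[1:]) / np.sum(p[1:]))


N, T = 200, 1.0
x = np.arange(N) * (T / N)
# Equal-amplitude components at 2 Hz and 4 Hz, plus an offset that F ignores.
y = np.sin(2 * np.pi * 2 * x) + np.sin(2 * np.pi * 4 * x) + 1.0
print(mean_frequency(y, T))  # approximately 3.0
```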
<p>After applying this $F$ with our already found $B$, we get closer to approximating the signal.</p>
<p><img src="/assets/images/PAE/signal_with_F_B.png" alt="Approximate Frequency" /></p>
<p>With the bias and frequency accounted for, we set the amplitude $A$ so that the average power of the signal is preserved. Since $\Gamma$ is a sine curve with the single frequency $F$, we choose $A$ so that the power in this one frequency bin equals the average power of the (mean-removed) signal, $\frac{1}{N}\sum_{j=1}^K \mathbf{p}_j$. A sinusoid of amplitude $A$ appears in a single-sided power spectrum with height $\left(\frac{A}{\sqrt{2}}\right)^2 = \frac{A^2}{2}$ [2], so setting $\frac{A^2}{2} = \frac{1}{N}\sum_{j=1}^K \mathbf{p}_j$ and rearranging gives</p>
\[A = \sqrt{\frac{2}{N}\sum_{j=1}^K \mathbf{p}_j}.\]
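<p>For a pure sinusoid with a whole number of cycles in the window, this choice recovers the true amplitude, since all of the non-zero-frequency power sits in a single bin. A NumPy sketch with an arbitrary test signal:</p>

```python
import numpy as np


def amplitude(y: np.ndarray) -> float:
    """A = sqrt((2/N) * sum_{j>=1} p_j), with p as defined in the text."""
    N = len(y)
    p = (2.0 / N) * np.abs(np.fft.rfft(y)) ** 2
    return float(np.sqrt(2.0 / N * np.sum(p[1:])))


N = 128
x = np.arange(N) / N
y = 0.7 * np.sin(2 * np.pi * 5 * x + 0.4) - 0.2  # amplitude 0.7, offset -0.2
print(amplitude(y))  # approximately 0.7
```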
<p>Adding this value of $A$ to $\Gamma$, our approximation gets closer again.</p>
<p><img src="/assets/images/PAE/signal_with_A_F_B.png" alt="Approximate Amplitude" /></p>
<p>Finally, we aim to find $S$. For signals that are not exactly periodic over the time window, the phase extracted directly from the DFT will show discontinuities. To avoid these discontinuities, the PAE instead learns a 2D phase representation, $(s_x, s_y)$, with a small neural network, from which we calculate $S$ as the corresponding angle, divided by $2\pi$ so that it is measured in cycles to match the $2 \pi (Fx - S)$ term in $\Gamma$:</p>
\[S = \frac{\operatorname{arctan2}(s_y, s_x)}{2\pi}.\]
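<p>A sketch of evaluating $\Gamma$ from a 2D phase vector, with fixed numbers standing in for the output of the small network (an assumption). It takes $S$ in cycles, i.e. the angle divided by $2\pi$, so that it matches the $2\pi(Fx - S)$ argument of $\Gamma$. Because only the direction of $(s_x, s_y)$ matters, rescaling the vector leaves $\Gamma$ unchanged, and the vector can move smoothly through the point where the angle wraps from $\pi$ to $-\pi$.</p>

```python
import numpy as np


def gamma(x, A, F, B, s_x, s_y):
    """Gamma(x) = A sin(2 pi (F x - S)) + B with S taken from a 2D phase (s_x, s_y).

    Assumes S is measured in cycles (arctan2 divided by 2 pi) to match the
    2 pi (F x - S) argument of Gamma.
    """
    S = np.arctan2(s_y, s_x) / (2 * np.pi)
    return A * np.sin(2 * np.pi * (F * x - S)) + B


x = np.linspace(0.0, 1.0, 50)
theta = 0.25 * 2 * np.pi  # a quarter-cycle phase, expressed as an angle
g1 = gamma(x, 0.5, 2.0, 0.1, np.cos(theta), np.sin(theta))
g2 = gamma(x, 0.5, 2.0, 0.1, 3 * np.cos(theta), 3 * np.sin(theta))
print(np.allclose(g1, g2))  # True: only the direction of (s_x, s_y) matters
```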
<p>After including $S$ as the final parameter of $\Gamma(x) = A \sin (2 \pi (Fx - S)) + B$, we can see that we have a reasonable approximation of our original signal.</p>
<p><img src="/assets/images/PAE/reconstructed_signal.png" alt="Reconstructed Signal" /></p>
<p>Because this parameterization is fully differentiable, the PAE encoder is encouraged to extract latent features that remain useful for reconstructing the original data after this severe dimensionality reduction (to 5 parameters, $F$, $A$, $B$, $s_x$ and $s_y$, per latent curve). Intuitively, this means the phase representation must capture a lot of “information” about the current state of the original time series data within the available context.</p>
<hr />
<h1 id="addendum">Addendum</h1>
<p>In the final figure the phase value was actually calculated with the following equations (thanks to <a href="https://mariogeiger.ch" target="_blank">Mario Geiger</a>). You may have some luck replacing the small phase learning network with these operations, but I haven’t had time to check that it is well behaved in practical applications.</p>
\[s_x = \sum_{i=1}^{N}(y_i - B) \cos(2 \pi F x_i)\]
\[s_y = \sum_{i=1}^{N}(y_i - B) \sin(2 \pi F x_i)\]
\[S = \operatorname{arctan2}(s_y, s_x)\]
<p>where $\{(x_i, y_i)\}$ are the $N$ samples that make up our signal (blue curve). Here $S$ is an angle in radians, and with these values we find our approximation (orange curve) as:</p>
\[A \cos(2 \pi F x - S) + B.\]
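<p>Putting the pieces together, the sketch below fits all parameters in closed form: $B$, $F$ and $A$ from the equations in the main text, and the phase from the addendum's projections (with $S$ in radians, as in the cosine form above). It assumes evenly spaced samples over a window of $T$ seconds and is not the DeepPhase code; as noted above, whether it behaves well inside the PAE is untested. On a pure cosine with a whole number of cycles in the window it should recover the generating parameters.</p>

```python
import numpy as np


def fit_sinusoid(x: np.ndarray, y: np.ndarray, T: float):
    """Closed-form (A, F, B, S) from the post's equations and the addendum."""
    N = len(y)
    c = np.fft.rfft(y)
    p = (2.0 / N) * np.abs(c) ** 2
    p[0] /= 2.0
    f = np.fft.rfftfreq(N, d=T / N)

    B = c[0].real / N                                  # offset: c_0 / N
    F = np.sum(f[1:] * p[1:]) / np.sum(p[1:])          # power-weighted mean frequency
    A = np.sqrt(2.0 / N * np.sum(p[1:]))               # amplitude preserving average power
    s_x = np.sum((y - B) * np.cos(2 * np.pi * F * x))  # addendum projections
    s_y = np.sum((y - B) * np.sin(2 * np.pi * F * x))
    S = np.arctan2(s_y, s_x)                           # phase in radians
    return A, F, B, S


N, T = 100, 1.0
x = np.arange(N) * (T / N)
y = 2.0 * np.cos(2 * np.pi * 3.0 * x - 0.7) + 0.5
A, F, B, S = fit_sinusoid(x, y, T)
approx = A * np.cos(2 * np.pi * F * x - S) + B  # the addendum's approximation
print(A, F, B, S)  # recovers A ≈ 2, F ≈ 3, B ≈ 0.5, S ≈ 0.7
```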
<hr />
<h1 id="code">Code</h1>
<p>The figures in this post were generated with this <a href="https://ianxmason.github.io/assets/code/PAE/pae_fft.py" target="_blank">short python script</a>.</p>
<p>There is a basic implementation of the Periodic Autoencoder in the supplementary material files <a href="https://dl.acm.org/doi/10.1145/3528223.3530178" target="_blank">here</a>.</p>
<hr />
<h1 id="references">References</h1>
<p>[1] <a href="https://pytorch.org/docs/stable/generated/torch.fft.rfft.html" target="_blank">PyTorch Differentiable FFT</a></p>
<p>[2] <a href="https://www.sjsu.edu/people/burford.furman/docs/me120/FFT_tutorial_NI.pdf" target="_blank">The fundamentals of FFT-based signal analysis and measurement; Michael Cerna and Audrey F. Harvey; 2000.</a></p>
<p>[3] <a href="https://www.intechopen.com/chapters/40123" target="_blank">The usefulness of mean and median frequencies in electromyography analysis; Angkoon Phinyomark, Sirinee Thongpanja, Huosheng Hu, Pornchai Phukpattaranont and Chusak Limsakul; 2012.</a></p>
<p><em>Ian Mason. A brief explanation of the derivation of the PAE.</em></p>