Enhancing LangChain’s RetrievalQA for Real Source Links

Masato Naka
Jan 10, 2024


Introduction

RetrievalQA is a powerful tool for overcoming one of the significant limitations of LLM-based systems: they can only answer from information learned up to their training cutoff. By retrieving relevant documents at query time, RetrievalQA lets the model work with new information that was not covered during its training.

With RetrievalQA, you pass in a retriever (for example, one backed by a VectorStore or one that queries a custom dataset) to fetch the data or documents relevant to a question. The LLM then generates its response from this retrieved context, drawing on broader and more up-to-date information.

A notable advantage of RetrievalQA is that it can show where the LLM's responses come from. Knowing which sources an answer was derived from makes it possible to assess the credibility of the response and to dig into the information behind each answer.

In this article, we'll look at a specific challenge in RetrievalQA: with the default prompt, the returned source documents can include documents that are irrelevant to the answer. We'll walk through modifications to the RetrievalQA setup so that the returned sources are the ones actually used to derive the generated response.

Identifying the Issue

In RetrievalQA, setting return_source_documents=True returns not only the final answer from the LLM but also the documents that were passed to the LLM as context. However, it isn't clear which of those documents were actually used to derive the answer, so presenting the answer together with all of them risks citing irrelevant sources.

For instance, suppose the retriever returns three documents, which are inserted into the default prompt to generate the final answer:

Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.

<Document1.page_content>

<Document2.page_content>

<Document3.page_content>

Question: {question}
Helpful Answer:

If Document3, for example, is irrelevant and doesn't contribute to the final answer, returning all of result["source_documents"] to the user is suboptimal.

To address this issue, we want to refine the prompt so that the answer leaves out irrelevant documents and cites only those that directly contribute to it. This gives the user a more precise and relevant set of source documents.

In this post, we will primarily focus on adjustments within RetrievalQA to handle such cases. While improvements in the retriever’s document retrieval process are essential, the varying relevance of documents returned by different retrievers makes it crucial to fine-tune the RetrievalQA itself.

Let’s dive into the modifications needed to steer clear of scenarios where irrelevant documents are included in the final source documents.

Prompt Refinement

To enhance the default prompt, the following modifications were made:

prompt_template = """Use the following pieces of context to answer the question at the end. Please follow the following rules:
1. If the question is to request links, please only return the source links with no answer.
2. If you don't know the answer, don't try to make up an answer. Just say **I can't find the final answer but you may want to check the following links** and add the source links as a list.
3. If you find the answer, write the answer in a concise way and add the list of sources that are **directly** used to derive the answer. Exclude the sources that are irrelevant to the final answer.

{context}

Question: {question}
Helpful Answer:"""

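from langchain.prompts import PromptTemplate

# from_template infers the input variables ({context}, {question}) from the string
PROMPT = PromptTemplate.from_template(prompt_template)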

This new prompt distinguishes between questions that need an answer and questions that only ask for links. When the LLM can't find the answer, it returns only the links that were passed to it; when it can, it lists only the sources directly used to derive the answer.
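To see what the LLM actually receives, the refined template can be rendered by hand. This is a minimal sketch using Python's built-in str.format, relying on the fact that LangChain's PromptTemplate uses f-string-style placeholders by default; the sample context string is a hypothetical stand-in for the stuffed documents.

```python
prompt_template = """Use the following pieces of context to answer the question at the end. Please follow the following rules:
1. If the question is to request links, please only return the source links with no answer.
2. If you don't know the answer, don't try to make up an answer. Just say **I can't find the final answer but you may want to check the following links** and add the source links as a list.
3. If you find the answer, write the answer in a concise way and add the list of sources that are **directly** used to derive the answer. Exclude the sources that are irrelevant to the final answer.

{context}

Question: {question}
Helpful Answer:"""

# Hypothetical stuffed context: two documents joined by a blank line.
context = (
    "Japan has a population of 126 million people.\n\n"
    "United States has a population of 328 million people."
)

# str.format fills the named placeholders, just like an f-string-style template.
rendered = prompt_template.format(context=context, question="How many people live in Japan?")
print(rendered)
```

At this point the context carries no source information, so the LLM could not cite anything yet; the metadata and document prompt changes below fix that.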

Add Source Information to Metadata

The prompt asks the LLM to include sources in the final answer. A source can be any information that identifies a document, such as a URL, a title, or keywords.

When you use a DocumentLoader to create documents and store them in a VectorStore, there are two ways to put this information into the metadata: have the DocumentLoader store it in the metadata when it creates the documents, or add it to the metadata yourself before adding the documents to the VectorStore.

To find out whether your DocumentLoader already stores the necessary information in the metadata, inspect the implementation of the loader you are using.

For instance, ConfluenceLoader sets the metadata by default in its process_page method and stores the page URL in the source field:

metadata = {
    "title": page["title"],
    "id": page["id"],
    "source": self.base_url.strip("/") + page["_links"]["webui"],
}

For other DocumentLoaders, check what the metadata contains and, if necessary, modify the implementation to store the metadata yourself. This way, you can include the sources you want in the final answer.
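As an illustration of the loader-side approach, here is a minimal sketch of a loader that builds its own source URL, following the same base-URL-plus-path pattern as ConfluenceLoader's process_page. Everything here is hypothetical: Doc is a stand-in for LangChain's Document, WikiPageLoader is not a real LangChain class, and the page dict layout ("title", "path", "body") is assumed; only the method names load and lazy_load are borrowed from LangChain's loader convention.

```python
from dataclasses import dataclass, field
from typing import Dict, Iterator, List


@dataclass
class Doc:
    """Stand-in for langchain's Document (page_content + metadata)."""
    page_content: str
    metadata: Dict[str, str] = field(default_factory=dict)


class WikiPageLoader:
    """Hypothetical loader that records each page's URL under metadata["source"]."""

    def __init__(self, base_url: str, pages: List[Dict[str, str]]):
        self.base_url = base_url
        self.pages = pages

    def lazy_load(self) -> Iterator[Doc]:
        for page in self.pages:
            # Same idea as ConfluenceLoader: base URL without trailing "/" + page path.
            metadata = {
                "title": page["title"],
                "source": self.base_url.strip("/") + page["path"],
            }
            yield Doc(page_content=page["body"], metadata=metadata)

    def load(self) -> List[Doc]:
        return list(self.lazy_load())


loader = WikiPageLoader(
    "https://en.wikipedia.org/",
    [{"title": "Japan", "path": "/wiki/Japan", "body": "Japan has a population of 126 million people."}],
)
docs = loader.load()
print(docs[0].metadata["source"])  # https://en.wikipedia.org/wiki/Japan
```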

Alternatively, without modifying the DocumentLoader, you can add the required information to the metadata before adding the documents to the VectorStore:

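from langchain.schema.document import Document

# user_msg is the text to index; vector_store is your existing VectorStore.
docs = [
    Document(
        page_content=user_msg,
        metadata={"source": "https://example.com/sample.html"},
    )
]
vector_store.add_documents(docs)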
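Because the custom document prompt used later in this post expects every document to carry a source key, it can help to fill in any missing ones in a single pass before indexing. This is a minimal sketch under stated assumptions: Doc is a stand-in for LangChain's Document (same page_content and metadata fields), and add_source_metadata and the fallback URL are hypothetical names chosen for illustration.

```python
from dataclasses import dataclass, field
from typing import Callable, Dict, List


@dataclass
class Doc:
    """Stand-in for langchain's Document: page_content plus a metadata dict."""
    page_content: str
    metadata: Dict[str, str] = field(default_factory=dict)


def add_source_metadata(docs: List[Doc], source_for: Callable[[Doc], str]) -> List[Doc]:
    """Hypothetical helper: make sure every document has a "source" entry.

    Documents that already have a source (e.g. set by the loader) are left unchanged.
    """
    for doc in docs:
        if "source" not in doc.metadata:
            doc.metadata["source"] = source_for(doc)
    return docs


docs = [
    Doc("Japan has a population of 126 million people.",
        {"source": "https://en.wikipedia.org/wiki/Japan"}),
    Doc("Some text loaded without a URL."),
]
add_source_metadata(docs, lambda d: "https://example.com/sample.html")
for d in docs:
    print(d.metadata["source"])
```

With real LangChain Document objects the same loop applies unchanged, and the resulting list can be passed to vector_store.add_documents(docs).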

Update document prompt

In RetrievalQA, a documents chain determines how the given documents fill the {context} placeholder in the final prompt:

prompt_template = """Use the following pieces of context to answer the question at the end. Please follow the following rules:
1. If the question is to request links, please only return the source links with no answer.
2. If you don't know the answer, don't try to make up an answer. Just say **I can't find the final answer but you may want to check the following links** and add the source links as a list.
3. If you find the answer, write the answer in a concise way and add the list of sources that are **directly** used to derive the answer. Exclude the sources that are irrelevant to the final answer.

{context}

Question: {question}
Helpful Answer:"""

In this post, we use StuffDocumentsChain, the default documents chain in RetrievalQA. The other documents chains don't pass all the original documents directly to the LLM in a single prompt, so there would be no way to determine which documents contributed to the final answer.

In StuffDocumentsChain, document_prompt determines how a list of Documents is embedded into the prompt, as follows:

  1. For each document, call format_document with document_prompt, which by default just extracts page_content from the document, and collect the results in a list. In short, doc_strings is a list of the documents' page_content.
  2. Join the elements of the list with document_separator, whose default value is \n\n.
class StuffDocumentsChain(BaseCombineDocumentsChain):
    ...
    def _get_inputs(self, docs: List[Document], **kwargs: Any) -> dict:
        ...
        # Format each document according to the prompt
        doc_strings = [format_document(doc, self.document_prompt) for doc in docs]
        # Join the documents together to put them in the prompt.
        inputs = {
            k: v
            for k, v in kwargs.items()
            if k in self.llm_chain.prompt.input_variables
        }
        inputs[self.document_variable_name] = self.document_separator.join(doc_strings)
        return inputs
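The two steps can be reproduced in plain Python to see exactly what ends up in {context}. This is a simplified sketch, not LangChain's code: documents are represented as (page_content, metadata) tuples, and format_doc and stuff are hypothetical stand-ins for format_document and StuffDocumentsChain, using the default template "{page_content}" and the default separator "\n\n".

```python
from typing import Dict, List, Tuple

DocTuple = Tuple[str, Dict[str, str]]  # (page_content, metadata)


def format_doc(doc: DocTuple, template: str) -> str:
    """Simplified format_document: fill the template from content + metadata."""
    page_content, metadata = doc
    # Unused keys (like "source" with the default template) are ignored by str.format.
    return template.format(**{**metadata, "page_content": page_content})


def stuff(docs: List[DocTuple], template: str = "{page_content}", separator: str = "\n\n") -> str:
    """Mimic StuffDocumentsChain: format each document, then join with the separator."""
    doc_strings = [format_doc(doc, template) for doc in docs]
    return separator.join(doc_strings)


docs = [
    ("Japan has a population of 126 million people.", {"source": "https://en.wikipedia.org/wiki/Japan"}),
    ("Japanese people are very polite.", {"source": "https://en.wikipedia.org/wiki/Japanese_people"}),
    ("United States has a population of 328 million people.", {"source": "https://en.wikipedia.org/wiki/United_States"}),
]

context = stuff(docs)
print(context)
```

Even though every document has a source in its metadata, the default template drops it, so the LLM never sees which URL belongs to which piece of text.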

The default values for document_separator, the documents key, and document_prompt are defined as follows:

DEFAULT_DOCUMENT_SEPARATOR = "\n\n"
DOCUMENTS_KEY = "context"
DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template("{page_content}")

Simply put, {context} consists of the page_content of each document separated by two newlines, so the final prompt looks like the one we saw above:

Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.

<Document1.page_content>

<Document2.page_content>

<Document3.page_content>

Question: {question}
Helpful Answer:

We also want to pass the sources to the LLM so that it can tell us which source it used to derive the answer. We can do this by modifying document_prompt, which LangChain's documentation describes as follows:

document_prompt: Prompt used for formatting each document into a string. Input
    variables can be "page_content" or any metadata keys that are in all
    documents. "page_content" will automatically retrieve the
    `Document.page_content`, and all other inputs variables will be
    automatically retrieved from the `Document.metadata` dictionary. Default to
    a prompt that only contains `Document.page_content`.

If the source is stored under the source key in the metadata, you can use the following custom template. Note that every document must contain source in its metadata; otherwise formatting the document fails with an error.

document_prompt = PromptTemplate(
    input_variables=["page_content", "source"],
    template="Context:\ncontent:{page_content}\nsource:{source}",
)
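To check what this custom document prompt produces, and what happens when a document lacks a source, here is a self-contained plain-Python sketch. render_context is a hypothetical helper that mimics, rather than calls, LangChain's formatting; raising ValueError for a missing source is this sketch's own choice, reflecting the requirement above that every document carry one.

```python
from typing import Dict, List, Tuple

DocTuple = Tuple[str, Dict[str, str]]  # (page_content, metadata)

# Same template string as the custom document_prompt above.
DOCUMENT_TEMPLATE = "Context:\ncontent:{page_content}\nsource:{source}"


def render_context(docs: List[DocTuple], separator: str = "\n\n") -> str:
    """Format every document with the custom template and join them (hypothetical helper)."""
    parts = []
    for page_content, metadata in docs:
        if "source" not in metadata:
            # Fail early instead of sending the LLM a document it cannot cite.
            raise ValueError(f"document has no 'source' metadata: {page_content[:30]!r}")
        parts.append(DOCUMENT_TEMPLATE.format(page_content=page_content, source=metadata["source"]))
    return separator.join(parts)


docs = [
    ("Japan has a population of 126 million people.", {"source": "https://en.wikipedia.org/wiki/Japan"}),
    ("United States has a population of 328 million people.", {"source": "https://en.wikipedia.org/wiki/United_States"}),
]
ctx = render_context(docs)
print(ctx)
```

Each document now reaches the LLM as a "Context:" block with its content and source on separate lines, which is what allows the prompt's rule 3 to ask for only the directly used sources.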

Put everything together

Finally, we can initialize RetrievalQA with the custom prompt, the custom document prompt, and your retriever:

from langchain.prompts import PromptTemplate
from langchain.chains.llm import LLMChain
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
from langchain.chat_models import ChatOpenAI
from langchain.chains import RetrievalQA
from langchain_core.retrievers import BaseRetriever
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain.schema.document import Document
from typing import List


class CustomRetriever(BaseRetriever):
    """Always return three static documents for testing."""

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        return [
            Document(page_content="Japan has a population of 126 million people.", metadata={"source": "https://en.wikipedia.org/wiki/Japan"}),
            Document(page_content="Japanese people are very polite.", metadata={"source": "https://en.wikipedia.org/wiki/Japanese_people"}),
            Document(page_content="United States has a population of 328 million people.", metadata={"source": "https://en.wikipedia.org/wiki/United_States"}),
        ]

prompt_template = """Use the following pieces of context to answer the question at the end. Please follow the following rules:
1. If the question is to request links, please only return the source links with no answer.
2. If you don't know the answer, don't try to make up an answer. Just say **I can't find the final answer but you may want to check the following links** and add the source links as a list.
3. If you find the answer, write the answer in a concise way and add the list of sources that are **directly** used to derive the answer. Exclude the sources that are irrelevant to the final answer.

{context}

Question: {question}
Helpful Answer:"""


def main():
    retriever = CustomRetriever()
    QA_CHAIN_PROMPT = PromptTemplate.from_template(prompt_template)  # prompt_template defined above
    llm_chain = LLMChain(llm=ChatOpenAI(), prompt=QA_CHAIN_PROMPT, callbacks=None, verbose=True)
    document_prompt = PromptTemplate(
        input_variables=["page_content", "source"],
        template="Context:\ncontent:{page_content}\nsource:{source}",
    )
    combine_documents_chain = StuffDocumentsChain(
        llm_chain=llm_chain,
        document_variable_name="context",
        document_prompt=document_prompt,
        callbacks=None,
    )
    qa = RetrievalQA(
        combine_documents_chain=combine_documents_chain,
        callbacks=None,
        verbose=True,
        retriever=retriever,
        return_source_documents=True,
    )
    res = qa("How many people live in Japan?")
    print(res['result'])
    res = qa("How many people live in US?")
    print(res['result'])
    res = qa("How many people live in Singapore?")
    print(res['result'])

if __name__ == "__main__":
    main()

With this customized RetrievalQA, you get an answer together with the sources that were actually used to derive it.

Result:

Japan (Known case):

Japan has a population of 126 million people. 
Sources:
- https://en.wikipedia.org/wiki/Japan

It answered with the link used to derive the answer! You can see similar results for the US too.

Singapore (Unknown case):

The custom retriever only has information about Japan and the US. When you ask about Singapore, you want RetrievalQA to say that it can't find the answer, as specified in the prompt:

I can't find the final answer but you may want to check the following links:
1. https://en.wikipedia.org/wiki/Singapore
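Because the sources now appear in the answer text, you can also map them back to res["source_documents"] and spot links the LLM was never given. In the Singapore output above, for example, the listed link is not one of the three URLs the custom retriever returns. The following is an optional post-processing sketch, not part of RetrievalQA: split_cited_sources and its URL regex are hypothetical, documents are represented by their metadata dicts, and it assumes the answer lists sources as bare URLs, as in the outputs above.

```python
import re
from typing import Dict, List, Tuple

# Rough URL matcher: stops at whitespace and common closing punctuation (an assumption).
URL_PATTERN = re.compile(r"https?://[^\s)\]>,]+")


def split_cited_sources(
    answer: str, source_metadata: List[Dict[str, str]]
) -> Tuple[List[str], List[str]]:
    """Return (cited URLs that were retrieved, cited URLs that were not)."""
    retrieved = {meta.get("source") for meta in source_metadata}
    # Drop a trailing period and de-duplicate while keeping the order of appearance.
    cited = list(dict.fromkeys(url.rstrip(".") for url in URL_PATTERN.findall(answer)))
    used = [url for url in cited if url in retrieved]
    unknown = [url for url in cited if url not in retrieved]
    return used, unknown


retrieved_meta = [
    {"source": "https://en.wikipedia.org/wiki/Japan"},
    {"source": "https://en.wikipedia.org/wiki/Japanese_people"},
    {"source": "https://en.wikipedia.org/wiki/United_States"},
]
japan_answer = (
    "Japan has a population of 126 million people.\n"
    "Sources:\n- https://en.wikipedia.org/wiki/Japan"
)
singapore_answer = (
    "I can't find the final answer but you may want to check the following links:\n"
    "1. https://en.wikipedia.org/wiki/Singapore"
)

print(split_cited_sources(japan_answer, retrieved_meta))
print(split_cited_sources(singapore_answer, retrieved_meta))
```

With real results you would pass [doc.metadata for doc in res["source_documents"]]. An empty first list combined with a non-empty second list signals that the cited links did not come from the retrieved context.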

Comparison with the normal RetrievalQA

You can run a standard RetrievalQA by replacing the RetrievalQA construction in main() with the following line:

qa = RetrievalQA.from_llm(llm=ChatOpenAI(), retriever=retriever, return_source_documents=True, verbose=True)

Result:
> Entering new RetrievalQA chain...

> Finished chain.
Japan has a population of 126 million people. [Document(page_content='Japan has a population of 126 million people.', metadata={'source': 'https://en.wikipedia.org/wiki/Japan'}), Document(page_content='Japanese people are very polite.', metadata={'source': 'https://en.wikipedia.org/wiki/Japanese_people'}), Document(page_content='United States has a population of 328 million people.', metadata={'source': 'https://en.wikipedia.org/wiki/United_States'})]


> Entering new RetrievalQA chain...

> Finished chain.
The population of the United States is approximately 328 million people. [Document(page_content='Japan has a population of 126 million people.', metadata={'source': 'https://en.wikipedia.org/wiki/Japan'}), Document(page_content='Japanese people are very polite.', metadata={'source': 'https://en.wikipedia.org/wiki/Japanese_people'}), Document(page_content='United States has a population of 328 million people.', metadata={'source': 'https://en.wikipedia.org/wiki/United_States'})]


> Entering new RetrievalQA chain...

> Finished chain.
Singapore has a population of approximately 5.7 million people. [Document(page_content='Japan has a population of 126 million people.', metadata={'source': 'https://en.wikipedia.org/wiki/Japan'}), Document(page_content='Japanese people are very polite.', metadata={'source': 'https://en.wikipedia.org/wiki/Japanese_people'}), Document(page_content='United States has a population of 328 million people.', metadata={'source': 'https://en.wikipedia.org/wiki/United_States'})]

With return_source_documents=True in a standard RetrievalQA, all the retrieved documents are visible, but it remains unclear which of them was actually used to derive the final answer. Moreover, the LLM may generate an answer from information that was not included in the provided documents, such as the population of Singapore. This is a significant drawback when the goal is to restrict answers to the documents at hand, and not knowing which document contributed to the answer reduces the precision of, and control over, the cited sources.

Summary

In RetrievalQA, it is unclear which of the returned documents contribute to the final answer, so we implemented a solution in three key steps.

  1. First, we customized the prompt to explicitly return the sources of the documents used to derive the final answer.
  2. Second, we introduced ways to store the desired information in the documents' metadata, either during document creation in the DocumentLoader or before adding the documents to the VectorStore.
  3. Lastly, we tailored the document prompt in the StuffDocumentsChain to include the information we wanted as sources in the context.

This enhancement addresses the inherent limitation of not knowing which documents are actively used in generating the final answer even when using return_source_documents=True.

To illustrate the impact of these modifications, we walked through example code comparing a standard RetrievalQA with our customized RetrievalQA. The customized version returns not just retrieved information but the specific sources behind each answer, which makes the language model's output in the RetrievalQA framework more precise and easier to verify.


Masato Naka

An SRE engineer, mainly working on Kubernetes. CKA (Feb 2021). His interests include cloud-native application development and machine learning.