Local GraphRAG + LangChain + GPT-4o = Easy AI/Chat for your Docs
In this story, I have a super quick tutorial showing you how to create an AI for your PDF with local GraphRAG, LangChain, and a local LLM to make a powerful agent chatbot for your business or personal use.
GraphRAG is a RAG approach that combines knowledge graphs with generative AI to answer questions that conventional RAG pipelines struggle with.
In my last GraphRAG video, we used an API to build the knowledge graph. This time, we build the graph and the retrieval algorithm ourselves, so we have full control over our data. That sounds great.
But here’s the kicker: with this method, you can completely tweak and customize your app to fit your needs perfectly.
I wanted to create this story to walk you through the most straightforward method I found, so you can get things set up and improve your AI chatbot's knowledge.
Let me give you a quick demo of a live chatbot to show you what I mean.
Let me give you a simple example: What is RagChecker? (If you haven't seen my video on it, I highly recommend it; in my view, RagChecker is way better than Ragas.) If you look at how the local GraphRAG generates its output, you will see that when a user asks a question, it converts the question into a vector and searches the stored vectors for the most relevant chunks of information. It then explores the knowledge graph with a Dijkstra-like algorithm (Dijkstra's algorithm finds the shortest path between points), starting from the nodes most closely related to the user's question.
As it explores each node, it adds that node's text to a growing context. If the context doesn't fully answer the question yet, it looks at neighbouring nodes and updates their priority based on connection strength. It repeats this until it finds a complete answer; if the graph never yields one, it asks a large language model to generate an answer from the accumulated context. The knowledge graph is then visualized with nodes representing chunks of text and edges showing how they're related. Edges are shaded on a blue colour scale according to connection strength, the path used to find the answer is drawn with red, curved, dashed arrows, and the start and end nodes of that path are marked in light green and light coral.
Before we start! 🦸🏻♀️
If you like this topic and you want to support me:
Like my article; that will really help me out. 👏
Follow me on my YouTube channel
Subscribe to me to get the latest article.
Now, let's get on with the guide on how to build an AI chatbot using local GraphRAG. When it came to creating this chatbot, I had many options. I spent a few weeks upskilling and exploring the available technologies, and I became familiar with LangChain, spaCy, scikit-learn, the Natural Language Toolkit (NLTK), and other libraries that make it easy to define and interact with different kinds of abstractions, which in turn makes it easy to build powerful chatbots.
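```python
import networkx as nx
from langchain_community.vectorstores import FAISS
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_core.prompts import PromptTemplate
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import LLMChainExtractor
from langchain_community.callbacks import get_openai_callback
from langchain_community.document_loaders import PyPDFLoader
from sklearn.metrics.pairwise import cosine_similarity
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import os
import sys
from dotenv import load_dotenv
from langchain_openai import ChatOpenAI
from langchain_ollama.chat_models import ChatOllama
from typing import List, Tuple, Dict
from nltk.stem import WordNetLemmatizer
# langchain_core.pydantic_v1 is deprecated; recent LangChain versions accept Pydantic v2 models directly
from pydantic import BaseModel, Field
import nltk
import spacy
import heapq
from langchain_openai import OpenAIEmbeddings
import streamlit as st
from concurrent.futures import ThreadPoolExecutor, as_completed
from tqdm import tqdm
import numpy as np
from streamlit_chat import message
from spacy.cli import download

# WordNetLemmatizer needs the WordNet corpus; fetch it once if it is missing
nltk.download("wordnet", quiet=True)


# Schemas for the structured LLM outputs used below. The field names follow
# how the code reads them (concepts_list, is_complete, answer).
class Concepts(BaseModel):
    concepts_list: List[str] = Field(description="List of key concepts")


class AnswerCheck(BaseModel):
    is_complete: bool = Field(description="Whether the context fully answers the query")
    answer: str = Field(description="The answer based on the context, if complete")
```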
I define a DocumentProcessor class that splits documents into smaller chunks, creates embeddings for them, and compares those embeddings to find similar chunks.
The process_documents method splits the documents and builds a FAISS "vector store" in which each chunk of text is stored along with its embedding, which makes it easy to search and compare chunks later. Sometimes you have a lot of text to process, and embedding it all at once is slow.
So create_embeddings_batch embeds the texts in batches (32 by default) and combines the results into one array, and compute_similarity_matrix compares every embedding with every other one. I use cosine similarity, which measures how closely two vectors point in the same direction.
```python
# Define the DocumentProcessor class
class DocumentProcessor:
    def __init__(self):
        """
        Initializes the DocumentProcessor with a text splitter and OpenAI embeddings.
        Attributes:
        - text_splitter: An instance of RecursiveCharacterTextSplitter with specified chunk size and overlap.
        - embeddings: An instance of OpenAIEmbeddings used for embedding documents.
        """
        self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
        self.embeddings = OpenAIEmbeddings()

    def process_documents(self, documents):
        """
        Processes a list of documents by splitting them into smaller chunks and creating a vector store.
        Args:
        - documents (list of Document): A list of LangChain documents to be processed.
        Returns:
        - tuple: A tuple containing:
          - splits (list of Document): The list of split document chunks.
          - vector_store (FAISS): A FAISS vector store created from the split document chunks and their embeddings.
        """
        splits = self.text_splitter.split_documents(documents)
        vector_store = FAISS.from_documents(splits, self.embeddings)
        return splits, vector_store

    def create_embeddings_batch(self, texts, batch_size=32):
        """
        Creates embeddings for a list of texts in batches.
        Args:
        - texts (list of str): A list of texts to be embedded.
        - batch_size (int, optional): The number of texts to process in each batch. Default is 32.
        Returns:
        - numpy.ndarray: An array of embeddings for the input texts.
        """
        embeddings = []
        for i in range(0, len(texts), batch_size):
            batch = texts[i:i + batch_size]
            batch_embeddings = self.embeddings.embed_documents(batch)
            embeddings.extend(batch_embeddings)
        return np.array(embeddings)

    def compute_similarity_matrix(self, embeddings):
        """
        Computes a cosine similarity matrix for a given set of embeddings.
        Args:
        - embeddings (numpy.ndarray): An array of embeddings.
        Returns:
        - numpy.ndarray: A cosine similarity matrix for the input embeddings.
        """
        return cosine_similarity(embeddings)
```
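To make the two helpers above concrete, here is a small NumPy-only sketch of what they compute. It is illustrative rather than the article's code: `toy_embed` is a hypothetical stand-in for `OpenAIEmbeddings.embed_documents` that turns each text into a 3-number vector, and `cosine_similarity_matrix` computes by hand the same pairwise cosine similarity that scikit-learn's `cosine_similarity` returns (dot products of length-normalised rows).

```python
import numpy as np


def embed_in_batches(texts, embed_fn, batch_size=32):
    """Embed `texts` one slice at a time and stack the results into one array."""
    vectors = []
    for start in range(0, len(texts), batch_size):
        vectors.extend(embed_fn(texts[start:start + batch_size]))
    return np.array(vectors, dtype=float)


def cosine_similarity_matrix(embeddings):
    """Pairwise cosine similarity: dot products of the L2-normalised rows."""
    norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
    unit = embeddings / np.clip(norms, 1e-12, None)  # avoid dividing by zero
    return unit @ unit.T


def toy_embed(batch):
    """Hypothetical embedding: [length, number of 'a's, number of 'e's]."""
    return [[len(t), t.count("a"), t.count("e")] for t in batch]


texts = ["graph", "grape", "knowledge graph", "vector"]
embeddings = embed_in_batches(texts, toy_embed, batch_size=2)  # two batches of two
sims = cosine_similarity_matrix(embeddings)
print(np.round(sims, 3))
```

Every diagonal entry is 1 (each chunk is identical to itself) and the matrix is symmetric; in the class, `create_embeddings_batch` plays the role of `embed_in_batches`, and keeping each request small is what makes large documents manageable.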
I then create a KnowledgeGraph class to organize information and show how different ideas are connected. Imagine building a map where each piece of information is a point (a node), and lines (edges) connect points whose ideas are similar or related. At the core of this class is a graph structure built with networkx.
I also use a lemmatizer, which reduces words to their base form so that similar concepts are identified and compared accurately. The concept_cache acts as a memory bank: it stores concepts that were already extracted to prevent redundant processing, which speeds things up. The edges_threshold (0.8) determines how similar two chunks need to be before they are connected in the graph.
Next, I made a build_graph method that ties all these components together. It adds each document split as a node in the graph, creates embeddings so the system can measure how similar the splits are, uses a large language model to extract key concepts from each split, and finally connects nodes with edges when they are sufficiently similar.
```python
# Define the KnowledgeGraph class
class KnowledgeGraph:
    def __init__(self):
        """
        Initializes the KnowledgeGraph with a graph, lemmatizer, and NLP model.
        Attributes:
        - graph: An instance of a networkx Graph.
        - lemmatizer: An instance of WordNetLemmatizer.
        - concept_cache: A dictionary to cache extracted concepts.
        - nlp: An instance of a spaCy NLP model.
        - edges_threshold: A float value that sets the threshold for adding edges based on similarity.
        """
        self.graph = nx.Graph()
        self.lemmatizer = WordNetLemmatizer()
        self.concept_cache = {}
        self.nlp = self._load_spacy_model()
        self.edges_threshold = 0.8

    def build_graph(self, splits, llm, embedding_model):
        """
        Builds the knowledge graph by adding nodes, creating embeddings, extracting concepts, and adding edges.
        Args:
        - splits (list): A list of document splits.
        - llm: An instance of a large language model.
        - embedding_model: An instance of an embedding model.
        Returns:
        - None
        """
        self._add_nodes(splits)
        embeddings = self._create_embeddings(splits, embedding_model)
        self._extract_concepts(splits, llm)
        self._add_edges(embeddings)

    def _add_nodes(self, splits):
        """
        Adds nodes to the graph from the document splits.
        Args:
        - splits (list): A list of document splits.
        Returns:
        - None
        """
        for i, split in enumerate(splits):
            self.graph.add_node(i, content=split.page_content)
```
In this part, I create the _create_embeddings method, which takes the document splits and creates an embedding for each one. It collects the text of every split and uses the embedding model to convert those texts into numerical embeddings.
It returns an array with one embedding per split. Next, the _compute_similarities method calculates how alike the splits are by measuring the cosine similarity between their embeddings.
Cosine similarity compares two vectors by the angle between them, with values close to 1 indicating high similarity. Lastly, _load_spacy_model loads spaCy's en_core_web_sm model (spaCy is a library for understanding and processing text), downloading it first if necessary.
```python
    def _create_embeddings(self, splits, embedding_model):
        """
        Creates embeddings for the document splits using the embedding model.
        Args:
        - splits (list): A list of document splits.
        - embedding_model: An instance of an embedding model.
        Returns:
        - numpy.ndarray: An array of embeddings for the document splits.
        """
        texts = [split.page_content for split in splits]
        # embed_documents returns a list of lists; convert it to match the docstring
        return np.array(embedding_model.embed_documents(texts))

    def _compute_similarities(self, embeddings):
        """
        Computes the cosine similarity matrix for the embeddings.
        Args:
        - embeddings (numpy.ndarray): An array of embeddings.
        Returns:
        - numpy.ndarray: A cosine similarity matrix for the embeddings.
        """
        return cosine_similarity(embeddings)

    def _load_spacy_model(self):
        """
        Loads the spaCy NLP model, downloading it if necessary.
        Args:
        - None
        Returns:
        - spacy.Language: An instance of a spaCy NLP model.
        """
        try:
            return spacy.load("en_core_web_sm")
        except OSError:
            print("Downloading spaCy model...")
            download("en_core_web_sm")
            return spacy.load("en_core_web_sm")
```
Let's create the _extract_concepts_and_entities method, which extracts important concepts and named entities, such as names of people, organizations, or places, from a given piece of text. It uses both spaCy and a large language model (LLM) to achieve this. It first checks whether the content has already been processed and cached, to avoid redundant work. If not, it uses spaCy to find named entities and the LLM to extract general concepts, then caches the combined list.
Then, the _extract_concepts method extracts concepts from many document splits at the same time to speed up the process. It uses a ThreadPoolExecutor to run the tasks concurrently, where each task calls _extract_concepts_and_entities for a different split. As each task completes, it stores the concepts found on the matching graph node, using tqdm to show progress.
Next, we write _add_edges, which adds connections (edges) between document splits based on their similarity and shared concepts. It calculates how similar each split is to every other split using their embeddings.
For each pair whose similarity score exceeds the threshold, it finds the concepts the two splits share and calculates an edge weight from both the similarity and the shared concepts. Finally, it adds the edge to the graph, storing the similarity score, edge weight, and shared concepts, while tqdm tracks progress.
```python
    def _extract_concepts_and_entities(self, content, llm):
        """
        Extracts concepts and named entities from the content using spaCy and a large language model.
        Args:
        - content (str): The content from which to extract concepts and entities.
        - llm: An instance of a large language model.
        Returns:
        - list: A list of extracted concepts and entities.
        """
        if content in self.concept_cache:
            return self.concept_cache[content]

        # Extract named entities using spaCy
        doc = self.nlp(content)
        named_entities = [ent.text for ent in doc.ents if ent.label_ in ["PERSON", "ORG", "GPE", "WORK_OF_ART"]]

        # Extract general concepts using LLM
        concept_extraction_prompt = PromptTemplate(
            input_variables=["text"],
            template="Extract key concepts (excluding named entities) from the following text:\n\n{text}\n\nKey concepts:"
        )
        concept_chain = concept_extraction_prompt | llm.with_structured_output(Concepts)
        general_concepts = concept_chain.invoke({"text": content}).concepts_list

        # Combine named entities and general concepts
        all_concepts = list(set(named_entities + general_concepts))
        self.concept_cache[content] = all_concepts
        return all_concepts

    def _extract_concepts(self, splits, llm):
        """
        Extracts concepts for all document splits using multi-threading.
        Args:
        - splits (list): A list of document splits.
        - llm: An instance of a large language model.
        Returns:
        - None
        """
        with ThreadPoolExecutor() as executor:
            future_to_node = {executor.submit(self._extract_concepts_and_entities, split.page_content, llm): i
                              for i, split in enumerate(splits)}
            for future in tqdm(as_completed(future_to_node), total=len(splits), desc="Extracting concepts and entities"):
                node = future_to_node[future]
                concepts = future.result()
                self.graph.nodes[node]['concepts'] = concepts

    def _add_edges(self, embeddings):
        """
        Adds edges to the graph based on the similarity of embeddings and shared concepts.
        Args:
        - embeddings (numpy.ndarray): An array of embeddings for the document splits.
        Returns:
        - None
        """
        similarity_matrix = self._compute_similarities(embeddings)
        num_nodes = len(self.graph.nodes)
        for node1 in tqdm(range(num_nodes), desc="Adding edges"):
            for node2 in range(node1 + 1, num_nodes):
                similarity_score = similarity_matrix[node1][node2]
                if similarity_score > self.edges_threshold:
                    # Concepts the two splits have in common (was used but never defined)
                    shared_concepts = set(self.graph.nodes[node1]['concepts']) & set(self.graph.nodes[node2]['concepts'])
                    edge_weight = self._calculate_edge_weight(node1, node2, similarity_score, shared_concepts)
                    self.graph.add_edge(node1, node2, weight=edge_weight,
                                        similarity=similarity_score,
                                        shared_concepts=list(shared_concepts))
```
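The future-to-node dictionary in `_extract_concepts` is the usual way to map results from `as_completed` back to their inputs, because `as_completed` yields futures in completion order, not submission order. The sketch below shows that pattern with only the standard library; `fake_extract` is a hypothetical stand-in for `_extract_concepts_and_entities` that simply collects capitalised words, and the random sleep imitates uneven LLM response times.

```python
import random
import time
from concurrent.futures import ThreadPoolExecutor, as_completed


def fake_extract(text):
    """Hypothetical concept extractor: the capitalised words in `text`."""
    time.sleep(random.uniform(0, 0.01))  # simulate variable LLM response times
    return sorted({word.strip(".,") for word in text.split() if word[:1].isupper()})


def extract_all(chunks, extractor, max_workers=4):
    """Run `extractor` on every chunk concurrently; return {chunk_index: concepts}."""
    results = {}
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        # Remember which chunk each future belongs to
        future_to_idx = {executor.submit(extractor, chunk): i for i, chunk in enumerate(chunks)}
        for future in as_completed(future_to_idx):  # completion order, not input order
            results[future_to_idx[future]] = future.result()
    return results


chunks = ["Alice met Bob in Paris.", "The graph links Paris and Berlin.", "no names here"]
print(extract_all(chunks, fake_extract))
```

Whatever order the threads finish in, each result lands under the right chunk index, which is exactly what lets the class write concepts onto the correct graph node. Threads suit this job because the real work is mostly waiting on network calls to the LLM.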
Then we create the _calculate_edge_weight method. It takes the smaller of the two nodes' concept counts as the maximum possible number of shared concepts and computes what proportion of that maximum is actually shared. It combines this proportion with the similarity score, using alpha (0.7) and beta (0.3) to weigh their importance.
The final edge weight indicates how strong the link between the two splits is. The _lemmatize_concept method then normalizes a concept by converting it to lowercase, splitting it into words, and reducing each word to its root form with the lemmatizer. It then joins these root words back into a single string.
```python
    def _calculate_edge_weight(self, node1, node2, similarity_score, shared_concepts, alpha=0.7, beta=0.3):
        """
        Calculates the weight of an edge based on similarity score and shared concepts.
        Args:
        - node1 (int): The first node.
        - node2 (int): The second node.
        - similarity_score (float): The similarity score between the nodes.
        - shared_concepts (set): The set of shared concepts between the nodes.
        - alpha (float, optional): The weight of the similarity score. Default is 0.7.
        - beta (float, optional): The weight of the shared concepts. Default is 0.3.
        Returns:
        - float: The calculated weight of the edge.
        """
        max_possible_shared = min(len(self.graph.nodes[node1]['concepts']), len(self.graph.nodes[node2]['concepts']))
        normalized_shared_concepts = len(shared_concepts) / max_possible_shared if max_possible_shared > 0 else 0
        return alpha * similarity_score + beta * normalized_shared_concepts

    def _lemmatize_concept(self, concept):
        """
        Lemmatizes a given concept.
        Args:
        - concept (str): The concept to be lemmatized.
        Returns:
        - str: The lemmatized concept.
        """
        return ' '.join([self.lemmatizer.lemmatize(word) for word in concept.lower().split()])
```
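Putting `_add_edges` and `_calculate_edge_weight` together: a pair of chunks gets an edge only if their cosine similarity s exceeds edges_threshold (0.8), and the edge weight is then alpha * s + beta * shared / min(|C1|, |C2|), where C1 and C2 are the two concept lists. The plain-Python sketch below mirrors that logic outside the class so you can check the numbers by hand; the function name `maybe_edge` and the sample concepts are hypothetical.

```python
def maybe_edge(similarity, concepts1, concepts2, threshold=0.8, alpha=0.7, beta=0.3):
    """Return (weight, shared_concepts) for a candidate edge, or None if too dissimilar."""
    if similarity <= threshold:
        return None  # _add_edges only connects pairs with similarity > threshold
    shared = set(concepts1) & set(concepts2)
    max_possible_shared = min(len(concepts1), len(concepts2))
    normalized_shared = len(shared) / max_possible_shared if max_possible_shared > 0 else 0
    return alpha * similarity + beta * normalized_shared, shared


for sim, c1, c2 in [
    (0.9, ["graph", "rag", "llm"], ["graph", "llm"]),      # 0.7*0.9 + 0.3*(2/2) = 0.93
    (0.85, ["graph", "rag"], ["faiss", "vector", "rag"]),  # 0.7*0.85 + 0.3*(1/2) = 0.745
    (0.75, ["graph"], ["graph"]),                          # below threshold: no edge
]:
    edge = maybe_edge(sim, c1, c2)
    print("no edge" if edge is None else (round(edge[0], 3), sorted(edge[1])))
```

Because the shared-concept term is normalized by the smaller concept list, it lies between 0 and 1, and with alpha + beta = 1 the weight stays on the same scale as the similarity. Concepts are matched as exact strings here, as in the class; `_lemmatize_concept` only comes into play later, during traversal.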
We define the QueryEngine class to answer questions using several information sources. It relies on three main components: vector_store, which holds numerical representations (embeddings) of the text chunks; knowledge_graph, the structure that stores and connects information about different topics; and a large language model that understands and generates text.
The constructor stores these components, sets a limit on how much context can be handled at once (max_context_length), and creates answer_check_chain, a chain that checks whether an answer is complete.
The _create_answer_check_chain method builds that chain: a prompt asks the LLM whether the given context fully answers the query, and the response is parsed into an AnswerCheck object.
The _check_answer method uses this chain to evaluate the context. It sends the query and context to the chain and returns a tuple (is_complete, answer) indicating whether the context fully answers the question and, if it does, the answer.
```python
# Define the QueryEngine class
class QueryEngine:
    def __init__(self, vector_store, knowledge_graph, llm):
        self.vector_store = vector_store
        self.knowledge_graph = knowledge_graph
        self.llm = llm
        self.max_context_length = 4000
        self.answer_check_chain = self._create_answer_check_chain()

    def _create_answer_check_chain(self):
        """
        Creates a chain to check if the context provides a complete answer to the query.
        Args:
        - None
        Returns:
        - Chain: A chain to check if the context provides a complete answer.
        """
        answer_check_prompt = PromptTemplate(
            input_variables=["query", "context"],
            template="Given the query: '{query}'\n\nAnd the current context:\n{context}\n\nDoes this context provide a complete answer to the query? If yes, provide the answer. If no, state that the answer is incomplete.\n\nIs complete answer (Yes/No):\nAnswer (if complete):"
        )
        return answer_check_prompt | self.llm.with_structured_output(AnswerCheck)

    def _check_answer(self, query: str, context: str) -> Tuple[bool, str]:
        """
        Checks if the current context provides a complete answer to the query.
        Args:
        - query (str): The query to be answered.
        - context (str): The current context.
        Returns:
        - tuple: A tuple containing:
          - is_complete (bool): Whether the context provides a complete answer.
          - answer (str): The answer based on the context, if complete.
        """
        response = self.answer_check_chain.invoke({"query": query, "context": context})
        return response.is_complete, response.answer
```
So, let's write the _expand_context method, which improves query understanding by exploring the knowledge graph and gathering additional information from related chunks.
It starts by setting up variables that track the expanded context, the traversal path, the concepts visited so far, and the final answer. It uses a priority queue to explore the nodes most similar to the initially retrieved documents first, ordered by how strongly they are connected. It repeatedly takes the node with the highest priority, adds its content to the context, and pushes its neighbours onto the queue; each newly reached neighbour's content is added as well, after which the method checks whether the context now answers the query.
If the answer isn't complete, it keeps exploring neighboring nodes, updating their priorities whenever it finds a stronger connection and recording the concepts it has seen. If it still doesn't find a full answer, it uses the large language model (LLM) to generate one from all the gathered context. In short, _expand_context builds its context by carefully exploring and expanding the relevant information in the knowledge graph.
```python
    def _expand_context(self, query: str, relevant_docs) -> Tuple[str, List[int], Dict[int, str], str]:
        """
        Expands the context by traversing the knowledge graph using a Dijkstra-like approach.
        This method implements a modified version of Dijkstra's algorithm to explore the knowledge graph,
        prioritizing the most relevant and strongly connected information. The algorithm works as follows:
        1. Initialize:
           - Start with nodes corresponding to the most relevant documents.
           - Use a priority queue to manage the traversal order, where priority is based on connection strength.
           - Maintain a dictionary of best known "distances" (inverse of connection strengths) to each node.
        2. Traverse:
           - Always explore the node with the highest priority (strongest connection) next.
           - For each node, check if we've found a complete answer.
           - Explore the node's neighbors, updating their priorities if a stronger connection is found.
        3. Concept Handling:
           - Track visited concepts to guide the exploration towards new, relevant information.
           - Record the lemmatized concepts of each newly added neighbor in visited_concepts.
        4. Termination:
           - Stop if a complete answer is found.
           - Continue until the priority queue is empty (all reachable nodes explored).
        This approach ensures that:
        - We prioritize the most relevant and strongly connected information.
        - We explore new concepts systematically.
        - We find the most relevant answer by following the strongest connections in the knowledge graph.
        Args:
        - query (str): The query to be answered.
        - relevant_docs (List[Document]): A list of relevant documents to start the traversal.
        Returns:
        - tuple: A tuple containing:
          - expanded_context (str): The accumulated context from traversed nodes.
          - traversal_path (List[int]): The sequence of node indices visited.
          - filtered_content (Dict[int, str]): A mapping of node indices to their content.
          - final_answer (str): The final answer found, if any.
        """
        # Initialize variables
        expanded_context = ""
        traversal_path = []
        visited_concepts = set()
        filtered_content = {}
        final_answer = ""
        priority_queue = []
        distances = {}  # Stores the best known "distance" (inverse of connection strength) to each node
        expanded_nodes = set()  # Nodes whose neighbors have already been explored

        print("\nTraversing the knowledge graph:")

        # Initialize priority queue with closest nodes from relevant docs
        for doc in relevant_docs:
            # Find the most similar node in the knowledge graph for each relevant document
            closest_nodes = self.vector_store.similarity_search_with_score(doc.page_content, k=1)
            closest_node_content, score = closest_nodes[0]
            # Get the corresponding node in our knowledge graph
            closest_node = next(n for n in self.knowledge_graph.graph.nodes
                                if self.knowledge_graph.graph.nodes[n]['content'] == closest_node_content.page_content)
            # FAISS returns an L2 distance by default (lower = more similar), so it can be
            # used directly as the min-heap priority. Taking 1 / score would explore the
            # closest nodes last and divide by zero on an exact match.
            priority = float(score)
            if priority < distances.get(closest_node, float('inf')):
                heapq.heappush(priority_queue, (priority, closest_node))
                distances[closest_node] = priority

        step = 0
        while priority_queue:
            # Get the node with the highest priority (lowest distance value)
            current_priority, current_node = heapq.heappop(priority_queue)

            # Skip stale queue entries and nodes whose neighbors were already explored
            if current_priority > distances.get(current_node, float('inf')) or current_node in expanded_nodes:
                continue
            expanded_nodes.add(current_node)

            if current_node not in traversal_path:
                step += 1
                traversal_path.append(current_node)
                node_content = self.knowledge_graph.graph.nodes[current_node]['content']
                node_concepts = self.knowledge_graph.graph.nodes[current_node]['concepts']
                # Add node content to our accumulated context
                filtered_content[current_node] = node_content
                expanded_context += "\n" + node_content if expanded_context else node_content
                # Log the current step for debugging and visualization
                st.write(f"<span style='color:red;'>Step {step} - Node {current_node}:</span>", unsafe_allow_html=True)
                st.write(f"Content: {node_content[:100]}...")
                st.write(f"Concepts: {', '.join(node_concepts)}")
                print("-" * 50)

            # Explore neighbors. This also runs for nodes first reached as someone else's
            # neighbor, so the traversal can go more than one hop from the start nodes.
            for neighbor in self.knowledge_graph.graph.neighbors(current_node):
                edge_data = self.knowledge_graph.graph[current_node][neighbor]
                edge_weight = edge_data['weight']
                # Calculate new distance (priority) to the neighbor
                # Note: We use 1 / edge_weight because higher weights mean stronger connections
                distance = current_priority + (1 / edge_weight)
                # If we've found a stronger connection to the neighbor, update its distance
                if distance < distances.get(neighbor, float('inf')):
                    distances[neighbor] = distance
                    heapq.heappush(priority_queue, (distance, neighbor))
                    # Process the neighbor node if it's not already in our traversal path
                    if neighbor not in traversal_path:
                        step += 1
                        traversal_path.append(neighbor)
                        neighbor_content = self.knowledge_graph.graph.nodes[neighbor]['content']
                        neighbor_concepts = self.knowledge_graph.graph.nodes[neighbor]['concepts']
                        filtered_content[neighbor] = neighbor_content
                        expanded_context += "\n" + neighbor_content if expanded_context else neighbor_content
                        # Log the neighbor node information
                        st.write(f"<span style='color:red;'>Step {step} - Node {neighbor} (neighbor of {current_node}):</span>", unsafe_allow_html=True)
                        st.write(f"Content: {neighbor_content[:100]}...")
                        st.write(f"Concepts: {', '.join(neighbor_concepts)}")
                        print("-" * 50)
                        # Check if we have a complete answer after adding the neighbor's content
                        is_complete, answer = self._check_answer(query, expanded_context)
                        if is_complete:
                            final_answer = answer
                            break
                        # Process the neighbor's concepts
                        neighbor_concepts_set = set(self.knowledge_graph._lemmatize_concept(c) for c in neighbor_concepts)
                        if not neighbor_concepts_set.issubset(visited_concepts):
                            visited_concepts.update(neighbor_concepts_set)

            # If we found a final answer, break out of the main loop
            if final_answer:
                break

        # If we haven't found a complete answer, generate one using the LLM
        if not final_answer:
            print("\nGenerating final answer...")
            response_prompt = PromptTemplate(
                input_variables=["query", "context"],
                template="Based on the following context, please answer the query.\n\nContext: {context}\n\nQuery: {query}\n\nAnswer:"
            )
            response_chain = response_prompt | self.llm
            input_data = {"query": query, "context": expanded_context}
            # Chat models return a message object; keep only its text
            final_answer = response_chain.invoke(input_data).content

        return expanded_context, traversal_path, filtered_content, final_answer
```
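Stripped of the LLM calls and logging, the traversal above is a Dijkstra-like search over a min-heap in which following an edge costs 1 / weight, so strong edges are cheap. The sketch below runs the same idea on a toy graph using only the standard library. The adjacency-dict format, the `is_complete` predicate (a hypothetical stand-in for `_check_answer`), and all the numbers are illustrative assumptions. Like the method above, it adds a neighbour to the context as soon as it is first reached, and it explores a node's neighbours when the node is popped.

```python
import heapq


def traverse(graph, start_distances, is_complete):
    """
    Dijkstra-like context expansion.
    graph: {node: {neighbour: edge_weight}}, weights in (0, 1], higher = stronger.
    start_distances: {node: initial distance}, lower = more relevant to the query.
    is_complete: callable(path) -> bool, stand-in for the LLM answer check.
    Returns the nodes in the order their content would be added to the context.
    """
    distances = dict(start_distances)
    queue = [(d, n) for n, d in start_distances.items()]
    heapq.heapify(queue)
    path, expanded = [], set()
    while queue:
        dist, node = heapq.heappop(queue)
        if dist > distances[node] or node in expanded:
            continue  # stale heap entry, or neighbours already explored
        expanded.add(node)
        if node not in path:
            path.append(node)
        for neighbour, weight in graph[node].items():
            new_dist = dist + 1 / weight  # strong edge -> small step cost
            if new_dist < distances.get(neighbour, float("inf")):
                distances[neighbour] = new_dist
                heapq.heappush(queue, (new_dist, neighbour))
                if neighbour not in path:
                    path.append(neighbour)
                    if is_complete(path):
                        return path
    return path


toy_graph = {
    "A": {"B": 0.9, "C": 0.5},
    "B": {"A": 0.9, "D": 0.8},
    "C": {"A": 0.5},
    "D": {"B": 0.8},
}
# Start from A; pretend the answer is complete once chunk D is in the context.
print(traverse(toy_graph, {"A": 0.2}, lambda path: "D" in path))
```

Starting from D instead, `traverse(toy_graph, {"D": 0.1}, lambda path: False)` reaches C three hops away. That only works because a node's neighbours are explored when it is popped from the heap, even if it first entered the context as someone else's neighbour.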
The query method handles the entire process of answering a query. It first calls _retrieve_relevant_documents to fetch documents related to the query. With these documents, it expands the context by traversing the knowledge graph and looks for an answer.
It records the path taken through the knowledge graph and the filtered content of the visited nodes. If no answer was found during this process, it uses the language model to generate a final answer from the expanded context. It prints the number of tokens used and the cost of the operation, then returns the final answer, the traversal path, and the filtered content.
The _retrieve_relevant_documents method, on the other hand, focuses on finding relevant documents. It uses the vector store to compare the query with the stored chunks and retrieve the five most similar ones, then refines them with an LLM-based compressor (LLMChainExtractor) that keeps only the parts relevant to the query, providing a more focused context.
```python
    def query(self, query: str) -> Tuple[str, List[int], Dict[int, str]]:
        """
        Processes a query by retrieving relevant documents, expanding the context, and generating the final answer.
        Args:
        - query (str): The query to be answered.
        Returns:
        - tuple: A tuple containing:
          - final_answer (str): The final answer to the query.
          - traversal_path (list): The traversal path of nodes in the knowledge graph.
          - filtered_content (dict): The filtered content of nodes.
        """
        with get_openai_callback() as cb:
            st.write(f"\nProcessing query: {query}")
            relevant_docs = self._retrieve_relevant_documents(query)
            expanded_context, traversal_path, filtered_content, final_answer = self._expand_context(query, relevant_docs)
            if not final_answer:
                st.write("\nGenerating final answer...")
                response_prompt = PromptTemplate(
                    input_variables=["query", "context"],
                    template="Based on the following context, please answer the query.\n\nContext: {context}\n\nQuery: {query}\n\nAnswer:"
                )
                response_chain = response_prompt | self.llm
                input_data = {"query": query, "context": expanded_context}
                response = response_chain.invoke(input_data)
                # Chat models return a message object; keep only its text
                final_answer = response.content
            else:
                st.write("\nComplete answer found during traversal.")
            st.write(f"\nFinal Answer: {final_answer}")
            print(f"\nTotal Tokens: {cb.total_tokens}")
            print(f"Prompt Tokens: {cb.prompt_tokens}")
            print(f"Completion Tokens: {cb.completion_tokens}")
            print(f"Total Cost (USD): ${cb.total_cost}")
        return final_answer, traversal_path, filtered_content

    def _retrieve_relevant_documents(self, query: str):
        """
        Retrieves relevant documents based on the query using the vector store.
        Args:
        - query (str): The query to be answered.
        Returns:
        - list: A list of relevant documents.
        """
        print("\nRetrieving relevant documents...")
        retriever = self.vector_store.as_retriever(search_type="similarity", search_kwargs={"k": 5})
        compressor = LLMChainExtractor.from_llm(self.llm)
        compression_retriever = ContextualCompressionRetriever(base_compressor=compressor, base_retriever=retriever)
        return compression_retriever.invoke(query)
```
Then, I create the visualize_traversal function in the Visualizer class to draw the graph and highlight the path taken through it. It starts by creating a new directed graph and copying the nodes and edges of the original graph into it. Then, it creates a large Matplotlib figure to draw on.
The function calculates node positions with a spring layout algorithm and colours the edges with the Blues colour map according to their weights, so stronger edges appear darker. The nodes are drawn in light blue.
It highlights the traversal path with red, dashed, curved arrows between consecutive nodes. It labels each node on the path with its position in the path and its first concept; the remaining nodes are labelled with their first concept only. The start and end nodes of the path are highlighted in light green and light coral, respectively, to make them easy to identify.
Finally, the function adds a colour bar for the edge weights and a legend, and displays the plot with Streamlit, making the graph data easier to understand.
```python
# Define the Visualizer class
class Visualizer:
    @staticmethod
    def visualize_traversal(graph, traversal_path):
        traversal_graph = nx.DiGraph()

        # Add nodes and edges from the original graph
        for node in graph.nodes():
            traversal_graph.add_node(node)
        for u, v, data in graph.edges(data=True):
            traversal_graph.add_edge(u, v, **data)

        fig, ax = plt.subplots(figsize=(16, 12))

        # Generate positions for all nodes
        pos = nx.spring_layout(traversal_graph, k=1, iterations=50)

        # Draw regular edges with color based on weight
        edges = list(traversal_graph.edges())
        edge_weights = [traversal_graph[u][v].get('weight', 0.5) for u, v in edges]
        nx.draw_networkx_edges(traversal_graph, pos,
                               edgelist=edges,
                               edge_color=edge_weights,
                               edge_cmap=plt.cm.Blues,
                               width=2,
                               ax=ax)

        # Draw nodes
        nx.draw_networkx_nodes(traversal_graph, pos,
                               node_color='lightblue',
                               node_size=3000,
                               ax=ax)

        # Draw traversal path with curved arrows; "arc3,rad=0.3" bends each
        # arrow into a smooth curve between consecutive path nodes
        for i in range(len(traversal_path) - 1):
            start_pos = pos[traversal_path[i]]
            end_pos = pos[traversal_path[i + 1]]
            arrow = patches.FancyArrowPatch(start_pos, end_pos,
                                            connectionstyle="arc3,rad=0.3",
                                            color='red',
                                            arrowstyle="->",
                                            mutation_scale=20,
                                            linestyle='--',
                                            linewidth=2,
                                            zorder=4)
            ax.add_patch(arrow)

        # Prepare labels for the nodes
        labels = {}
        for i, node in enumerate(traversal_path):
            concepts = graph.nodes[node].get('concepts', [])
            labels[node] = f"{i + 1}. {concepts[0] if concepts else ''}"
        for node in traversal_graph.nodes():
            if node not in labels:
                concepts = graph.nodes[node].get('concepts', [])
                labels[node] = concepts[0] if concepts else ''

        # Draw labels
        nx.draw_networkx_labels(traversal_graph, pos, labels, font_size=8, font_weight="bold", ax=ax)

        # Highlight start and end nodes
        if traversal_path:
            nx.draw_networkx_nodes(traversal_graph, pos,
                                   nodelist=[traversal_path[0]],
                                   node_color='lightgreen',
                                   node_size=3000,
                                   ax=ax)
            nx.draw_networkx_nodes(traversal_graph, pos,
                                   nodelist=[traversal_path[-1]],
                                   node_color='lightcoral',
                                   node_size=3000,
                                   ax=ax)

        ax.set_title("Graph Traversal Flow")
        ax.axis('off')

        # Add colorbar for edge weights (defaults avoid an error when there are no edges)
        sm = plt.cm.ScalarMappable(cmap=plt.cm.Blues,
                                   norm=plt.Normalize(vmin=min(edge_weights, default=0.0),
                                                      vmax=max(edge_weights, default=1.0)))
        sm.set_array([])
        cbar = fig.colorbar(sm, ax=ax, orientation='vertical', fraction=0.046, pad=0.04)
        cbar.set_label('Edge Weight', rotation=270, labelpad=15)

        # Add legend
        regular_line = plt.Line2D([0], [0], color='blue', linewidth=2, label='Regular Edge')
        traversal_line = plt.Line2D([0], [0], color='red', linewidth=2, linestyle='--', label='Traversal Path')
        start_point = plt.Line2D([0], [0], marker='o', color='w', markerfacecolor='lightgreen', markersize=15, label='Start Node')
        end_point = plt.Line2D([0], [0], marker='o', color='w', markerfacecolor='lightcoral', markersize=15, label='End Node')
        legend = ax.legend(handles=[regular_line, traversal_line, start_point, end_point], loc='upper left', bbox_to_anchor=(0, 1), ncol=2)
        legend.get_frame().set_alpha(0.8)

        plt.tight_layout()

        # Streamlit display
        st.pyplot(fig)
```
Next, we create the GraphRAG class, which ties document processing and querying together.
When you create an instance, it initializes all the necessary tools:
- a large language model (OpenAI's gpt-4o-mini) to generate responses
- an embedding model to convert text into numerical vectors
- a document processor to split documents into chunks
- a knowledge graph to organize and connect information
- a visualizer to draw the traversal path

When processing documents, the class breaks them into chunks, converts these chunks into embeddings, and builds a graph that connects related pieces of information. Once the documents are processed, we can ask questions. The class searches the graph for relevant information, visualizes the path it followed, and returns the response.
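Here is a toy, dependency-free sketch of the same three processing steps (chunk, embed, connect). It is not the article's `DocumentProcessor` or `KnowledgeGraph`, which use OpenAI embeddings and LLM-extracted concepts. This sketch makes the following assumptions:
- Chunking is done by a fixed number of words.
- Bag-of-words counts stand in for real embeddings.
- Two chunks are joined by an edge when their cosine similarity exceeds a hypothetical threshold of 0.2.

All function names are hypothetical.

```python
# Toy sketch of "chunk -> embed -> connect"; bag-of-words vectors stand in for a real embedding model.
import math
import re
from collections import Counter
from typing import Dict, List, Tuple


def split_into_chunks(text: str, chunk_size: int = 8) -> List[str]:
    """Split text into chunks of at most chunk_size words (stand-in for a text splitter)."""
    words = text.split()
    return [" ".join(words[i:i + chunk_size]) for i in range(0, len(words), chunk_size)]


def embed(chunk: str) -> Counter:
    """Lower-cased word counts, used as a toy substitute for an embedding vector."""
    return Counter(re.findall(r"[a-z]+", chunk.lower()))


def cosine(a: Counter, b: Counter) -> float:
    """Cosine similarity of two sparse count vectors (0.0 if either is empty)."""
    dot = sum(a[t] * b[t] for t in a if t in b)
    norm = math.sqrt(sum(v * v for v in a.values())) * math.sqrt(sum(v * v for v in b.values()))
    return dot / norm if norm else 0.0


def build_graph(chunks: List[str], threshold: float = 0.2) -> Dict[Tuple[int, int], float]:
    """Connect every pair of chunks whose similarity exceeds threshold; the edge weight is the similarity."""
    vectors = [embed(c) for c in chunks]
    edges = {}
    for i in range(len(chunks)):
        for j in range(i + 1, len(chunks)):
            sim = cosine(vectors[i], vectors[j])
            if sim > threshold:
                edges[(i, j)] = sim
    return edges


if __name__ == "__main__":
    text = "graph rag retrieval graph rag answers cooking pasta tonight"
    chunks = split_into_chunks(text, chunk_size=3)
    print(chunks)
    print(build_graph(chunks))  # only chunks 0 and 1 are connected
```

Weighting edges by similarity is what later lets a traversal prefer strongly related chunks over weakly related ones.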
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
import streamlit as st
# DocumentProcessor, KnowledgeGraph, QueryEngine and Visualizer are defined earlier in this script.

class GraphRAG:
    def __init__(self):
        """
        Initializes the GraphRAG system with components for document processing, knowledge graph construction,
        querying, and visualization.
        Attributes:
        - llm: An instance of a large language model (LLM) for generating responses.
        - embedding_model: An instance of an embedding model for document embeddings.
        - document_processor: An instance of the DocumentProcessor class for processing documents.
        - knowledge_graph: An instance of the KnowledgeGraph class for building and managing the knowledge graph.
        - query_engine: An instance of the QueryEngine class for handling queries (None until documents are processed).
        - visualizer: An instance of the Visualizer class for visualizing the knowledge graph traversal.
        """
        self.llm = ChatOpenAI(temperature=0, model_name="gpt-4o-mini", max_tokens=4000)
        self.embedding_model = OpenAIEmbeddings()
        self.document_processor = DocumentProcessor()
        self.knowledge_graph = KnowledgeGraph()
        self.query_engine = None
        self.visualizer = Visualizer()

    def process_documents(self, documents):
        """
        Processes a list of documents by splitting them into chunks, embedding them, and building a knowledge graph.
        Args:
        - documents (list of Document): The documents to process (e.g. the pages returned by PyPDFLoader).
        Returns:
        - None
        """
        splits, vector_store = self.document_processor.process_documents(documents)
        self.knowledge_graph.build_graph(splits, self.llm, self.embedding_model)
        self.query_engine = QueryEngine(vector_store, self.knowledge_graph, self.llm)

    def query(self, query: str):
        """
        Handles a query by retrieving relevant information from the knowledge graph and visualizing the traversal path.
        Args:
        - query (str): The query to be answered.
        Returns:
        - str: The response to the query.
        """
        if self.query_engine is None:
            raise RuntimeError("Call process_documents() before query().")
        response, traversal_path, filtered_content = self.query_engine.query(query)
        if traversal_path:
            self.visualizer.visualize_traversal(self.knowledge_graph.graph, traversal_path)
        else:
            # Show the notice in the app rather than only in the server console
            st.info("No traversal path to visualize.")
        return response
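The `query` method passes the actual search to `QueryEngine`, which is defined earlier in the article. The traversal idea described at the start can be sketched with a priority queue. That idea is Dijkstra-style exploration that expands strongly connected nodes first and stops once the accumulated context answers the question. This is a simplified stand-in, not the article's `QueryEngine`. It makes the following assumptions:
- The graph is an adjacency dict of edge weights in (0, 1].
- Each edge costs `1 / weight`, so strong connections give short distances.
- A hypothetical `is_complete` callback replaces the LLM's "does this context answer the question?" check.

```python
# Dijkstra-style traversal sketch (not the article's QueryEngine; edge cost = 1 / weight is an assumption).
import heapq
from typing import Callable, Dict, List, Tuple

# node -> {neighbour: edge weight in (0, 1]}; a higher weight means a stronger connection
Graph = Dict[str, Dict[str, float]]


def traverse(graph: Graph, content: Dict[str, str], start_nodes: List[str],
             is_complete: Callable[[str], bool], max_nodes: int = 10) -> Tuple[List[str], str]:
    """Visit nodes in order of shortest accumulated distance, appending each node's text
    to the context until is_complete(context) is true or max_nodes nodes have been visited."""
    distances = {node: 0.0 for node in start_nodes}
    queue = [(0.0, node) for node in start_nodes]
    heapq.heapify(queue)
    visited = set()
    path: List[str] = []
    parts: List[str] = []
    while queue and len(path) < max_nodes:
        dist, node = heapq.heappop(queue)
        if node in visited:
            continue  # stale entry: the node was already reached by a shorter route
        visited.add(node)
        path.append(node)
        parts.append(content[node])
        if is_complete("\n".join(parts)):
            break
        # Relax neighbours: strong edges give short distances, so they are expanded first
        for neighbour, weight in graph.get(node, {}).items():
            new_dist = dist + 1.0 / weight
            if neighbour not in visited and new_dist < distances.get(neighbour, float("inf")):
                distances[neighbour] = new_dist
                heapq.heappush(queue, (new_dist, neighbour))
    return path, "\n".join(parts)


if __name__ == "__main__":
    graph = {
        "A": {"B": 0.9, "C": 0.2},
        "B": {"A": 0.9, "D": 0.5},
        "C": {"A": 0.2, "D": 0.9},
        "D": {"B": 0.5, "C": 0.9},
    }
    content = {"A": "alpha", "B": "beta", "C": "gamma", "D": "delta"}
    path, context = traverse(graph, content, ["A"], is_complete=lambda ctx: "delta" in ctx)
    print(path)  # ['A', 'B', 'D']
```

In the example, the strong edge A to B is followed before the weak edge A to C. The search stops as soon as the context contains the needed information, and the returned `path` is exactly the kind of list that `visualize_traversal` draws.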
Finally, we create the main function, which sets up a Streamlit app that lets users chat with a PDF document.
Users upload a PDF file. The app saves it to a temporary file, loads it with PyPDFLoader, and keeps only the first 10 pages. Once the PDF is loaded, the app initializes the chat history with a welcome message and sets up containers for the chat messages and the input form.
Users type their questions about the PDF into a text box. When they submit a query, GraphRAG answers it using the knowledge graph built from the document.
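PyPDFLoader reads from a file path, so the uploaded bytes have to be written to disk first. The stdlib sketch below shows that round trip with cleanup. `temporary_pdf` is a hypothetical helper. The `.pdf` suffix and the deletion on exit are additions for tidiness, not something every version of the app needs.

```python
# Hypothetical helper: write uploaded bytes to a temporary .pdf file and delete it afterwards.
import os
import tempfile
from contextlib import contextmanager
from typing import Iterator


@contextmanager
def temporary_pdf(data: bytes) -> Iterator[str]:
    """Write data to a named temporary .pdf file, yield its path, then delete the file."""
    # delete=False keeps the file after closing, so a loader can reopen it by path (also on Windows)
    with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as tmp_file:
        tmp_file.write(data)
        path = tmp_file.name
    try:
        yield path
    finally:
        os.remove(path)


if __name__ == "__main__":
    with temporary_pdf(b"%PDF-1.4 placeholder bytes") as pdf_path:
        print(pdf_path.endswith(".pdf"), os.path.exists(pdf_path))  # True True
    print(os.path.exists(pdf_path))  # False
```

In the app, the `PyPDFLoader(pdf_path).load()` call would go inside the `with` block, because the file no longer exists once the block exits.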
import os
import tempfile
import streamlit as st
from streamlit_chat import message
from langchain_community.document_loaders import PyPDFLoader

def main():
    # Streamlit setup
    st.title("Chat with PDF using local RagGraph 🕸️🦜")
    uploaded_file = st.file_uploader("Upload your PDF here 👇:", type="pdf")
    if uploaded_file is None:
        # Nothing to chat about until a PDF is uploaded
        return
    # Build the knowledge graph once per uploaded file instead of on every query
    if st.session_state.get("file_name") != uploaded_file.name:
        with st.spinner("Processing..."):
            # Save the uploaded file to a temporary location so PyPDFLoader can read it by path
            with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as tmp_file:
                tmp_file.write(uploaded_file.read())
                tmp_file_path = tmp_file.name
            try:
                # Load the PDF (one Document per page) and keep only the first 10 pages
                documents = PyPDFLoader(tmp_file_path).load()[:10]
            finally:
                os.remove(tmp_file_path)
            graph_rag = GraphRAG()
            graph_rag.process_documents(documents)
        st.session_state["graph_rag"] = graph_rag
        st.session_state["file_name"] = uploaded_file.name
        # Start a fresh chat history for the new document
        st.session_state["generated"] = ["Welcome! You can now ask any questions regarding " + uploaded_file.name]
        st.session_state["past"] = ["Hey!"]
    st.divider()
    # Container for chat history
    response_container = st.container()
    # Container for text box
    container = st.container()
    with container:
        with st.form(key='my_form', clear_on_submit=True):
            query = st.text_input("Enter your query:", key='input')
            submit_button = st.form_submit_button(label='Send')
        if submit_button and query:
            output = st.session_state["graph_rag"].query(query)
            st.session_state["past"].append(query)
            st.session_state["generated"].append(output)
    with response_container:
        for i in range(len(st.session_state["generated"])):
            message(st.session_state["past"][i], is_user=True, key=str(i) + '_user', avatar_style="thumbs")
            message(st.session_state["generated"][i], key=str(i), avatar_style="fun-emoji")

if __name__ == "__main__":
    main()
Conclusion:
This article showed how GraphRAG can help answer broad, complex questions about a document more accurately than conventional RAG, which matters in many applications.
Furthermore, pairing it with a small, inexpensive model keeps processing fast and costs low. The code above uses gpt-4o-mini, and a local model such as llama3:8b could be swapped in. This is good news for individual users, researchers, and businesses.
Try it yourself. You are welcome to share your test results in the comment section so that we can discuss them together.
If you think this article would help your friends, please share it with them.
🧙♂️ I am a generative AI expert! If you want to collaborate on a project, drop an inquiry here or book a 1-on-1 consulting call with me.