Imagine asking an AI chatbot for guidance on a complex topic like a medical condition, a lawsuit, or an advanced mathematical question. You expect a coherent, logical response, but instead you get something generic and only partially relevant. Frustrating, right?
Traditional AI chatbots often fail in these cases because they rely on keyword matching and have a limited understanding of context. They generate responses from pre-trained knowledge and a handful of retrieved documents, which makes them struggle with nuanced relationships between concepts.
GraphRAG addresses this problem by changing how information is retrieved and how responses are generated. It stores knowledge as a graph of related concepts rather than isolated chunks, which lets the AI reason more like a human: understanding relationships, following logical connections, and producing more relevant answers.
Retrieval-Augmented Generation (RAG) is a technique that improves AI responses by dynamically searching for information rather than relying only on pre-trained knowledge: it first retrieves relevant information and then generates a response based on it. Traditional RAG has limitations, however, because it often depends on keyword matching rather than conceptual understanding and cannot capture complex relationships between concepts.
GraphRAG addresses these problems by organizing knowledge as a graph in which nodes represent entities or concepts and edges represent the relationships between them. This lets the AI understand relations beyond keyword matches, discover indirect but logical connections, and retrieve more intelligent responses.
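To make this concrete, here is a tiny hypothetical fragment of such a graph, expressed as subject-relation-object triples (the entities and relations are made up purely for illustration):

# A tiny illustrative knowledge-graph fragment: nodes are entities, edges are relations
triples = [
    ("Aspirin", "TREATS", "Headache"),
    ("Aspirin", "BELONGS_TO", "NSAID drugs"),
    ("NSAID drugs", "CAN_CAUSE", "Stomach irritation"),
]
# Following the edges lets a system infer an indirect link:
# Aspirin -> NSAID drugs -> Stomach irritation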
GraphRAG has applications in various fields, such as:
This guide will cover:
Now, let's implement GraphRAG step by step.
We will implement our GraphRAG using Python.
Here are the required libraries:
spaCy: Used for natural language processing; efficient for tasks like tokenization and recognizing relationships between words.
textacy: Built on top of spaCy and provides built-in functions for relationship extraction.
Neo4j: A database that stores data as a graph instead of tables.
matplotlib: Used for data visualization.
networkx: Used to draw graphs.
en_core_web_sm: A small English language model provided by spaCy.
Install the required libraries using:
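Assuming a standard pip setup, the installation commands would look like this (the spaCy English model is downloaded as a separate step):

pip install spacy textacy neo4j matplotlib networkx
python -m spacy download en_core_web_sm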
We load the document that will serve as our knowledge base, extract its text, and strip any leading or trailing whitespace.
You can find the document here: Link
import os

# Defining a function to load the document
def load_document(file):
    all_text = ""
    with open(file, "r", encoding="utf-8") as f:
        all_text += f.read()
    return all_text.strip()

# Loading document
doc_filename = "Document.txt"
documents_text = load_document(doc_filename)
We load the small English language model provided by spaCy and use textacy's built-in function textacy.extract.subject_verb_object_triples() to extract subject-verb-object triples from the text. Then we print the relationships that were extracted.
Add the following code to your script:
import spacy
import textacy

# Loading small English language model
nlp = spacy.load("en_core_web_sm")

# Defining a function to extract subject-verb-object relationships
def extract_relationships(text):
    doc = nlp(text)
    relations = []
    for subj, verb, obj in textacy.extract.subject_verb_object_triples(doc):
        relations.append((subj.text if hasattr(subj, 'text') else subj,
                          verb.text if hasattr(verb, 'text') else verb,
                          obj.text if hasattr(obj, 'text') else obj))
    return relations

# Extracting relationships
relations = extract_relationships(documents_text)
print(relations)
We import the Neo4j Python driver, which allows interaction with a Neo4j database. Specifically, we import the GraphDatabase class, which is used to establish a connection to the graph database. We provide the connection details and create a driver with GraphDatabase.driver(). We then define a function that adds entities and their relationships to the database using the MERGE command, and call it in a loop to store all extracted entities and relationships.
from neo4j import GraphDatabase

# Neo4j connection details
URI = "YOUR_URI"
USERNAME = "YOUR_USERNAME"
PASSWORD = "YOUR_PASSWORD"
driver = GraphDatabase.driver(URI, auth=(USERNAME, PASSWORD))

# Defining a function to store entities and relationships in Neo4j
def add_relationship(tx, entity1, relation, entity2):
    query = f"""
    MERGE (e1:Entity {{name: $entity1}})
    MERGE (e2:Entity {{name: $entity2}})
    MERGE (e1)-[:`{relation}`]->(e2)
    """
    tx.run(query, entity1=entity1, relation=relation, entity2=entity2)

# Storing data in the Neo4j database
with driver.session() as session:
    for subj, rel, obj in relations:
        session.execute_write(add_relationship, str(subj), str(rel), str(obj))
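Optionally, you can run a quick count query to confirm that the entities were stored; the check below assumes the Entity label used above:

# Optional: counting stored entities to verify the import
with driver.session() as session:
    node_count = session.run("MATCH (n:Entity) RETURN count(n) AS total").single()["total"]
print("Entities stored in Neo4j:", node_count)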
You can visualize this using the following code:
# Defining a function to fetch data from the graph database
def fetch_graph_data(tx):
    query = """
    MATCH (a)-[r]->(b)
    RETURN a.name AS source, type(r) AS relation, b.name AS target
    """
    result = tx.run(query)
    # Using an empty list to store edges
    edges = []
    for record in result:
        edges.append((record["source"], record["target"], record["relation"]))
    return edges

# Fetching data from the database
with driver.session() as session:
    edges = session.read_transaction(fetch_graph_data)

# Note: we keep the driver open here because it is needed again for retrieval below; it is closed after the final query
We create a directed graph using NetworkX and define a helper function to clean up the node and relation text (the stored strings carry surrounding brackets and commas, which we strip). In a for loop, we add the formatted edges to the graph. We set the figure size with matplotlib and use a shell layout, which arranges nodes in concentric circles. We draw the graph with its styling, draw the edge labels, and finally set a title and show the plot.
import matplotlib.pyplot as plt
import networkx as nx

# Creating a directed graph
G = nx.DiGraph()

# Defining a function to properly format text
def modify_string(s):
    return s[1:-1].replace(',', '')

# Adding edges to the graph
for src, dest, rel in edges:
    src = modify_string(src)
    dest = modify_string(dest)
    rel = modify_string(rel)
    G.add_edge(src, dest, label=rel)

# Defining the figure size
plt.figure(figsize=(8, 6))
pos = nx.shell_layout(G)

# Drawing the graph (nodes and edges)
nx.draw(G, pos, with_labels=True, node_size=800, node_color="lightblue",
        edge_color="gray", font_size=6, font_weight="bold")

# Drawing edge labels
edge_labels = {(u, v): d["label"] for u, v, d in G.edges(data=True)}
nx.draw_networkx_edge_labels(G, pos, edge_labels=edge_labels, font_size=6, font_color="red")

# Setting the title
plt.title("Neo4j Graph Visualization with Matplotlib")
plt.show()
We define a function that returns the relevant entities from the database using a MATCH command that checks whether an entity's name contains the search text. We use OPTIONAL MATCH to handle the case where an entity exists but has no outgoing relationships to other nodes. We then define a function that fetches this data by opening a driver session and reading the results, run an example query, and store the relevant information.
# Defining a function to query the database
def query_neo4j(tx, query_text):
    query = """
    MATCH (e) WHERE e.name CONTAINS $query_text
    OPTIONAL MATCH (e)-[r]->(related)
    RETURN e.name AS entity, labels(e) AS label, related.name AS related_entity, type(r) AS relation
    """
    result = tx.run(query, query_text=query_text)
    return [record for record in result]

# Defining a function to extract relevant entities and their relationships
def get_relevant_knowledge(query_text):
    with driver.session() as session:
        results = session.read_transaction(query_neo4j, query_text)
    return results

# Example query
query = "AI"

# Extracting relevant information from the database
relevant_data = get_relevant_knowledge(query)

# Closing the connection now that all queries are done
driver.close()
We import requests, a library used to send HTTP requests in Python, and set the URL endpoint for our API call; here we use the Mistral-7B-Instruct-v0.3 model from Hugging Face. The HEADERS dictionary contains the authorization header with the API key. We define a function that generates a response from the LLM by sending it the prompt along with generation parameters: we cap the response length at 200 tokens and enable sampling. Finally, we build a prompt from the user's question and the relevant information extracted from the graph, so the model can combine the two into an informed answer.
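A minimal sketch of this step is shown below. It assumes the Hugging Face Inference API endpoint for Mistral-7B-Instruct-v0.3 and a placeholder API key; the helper name generate_response, the example user question, and the prompt format are illustrative choices, so adapt them to your setup:

import requests

# Hugging Face Inference API endpoint for the Mistral-7B-Instruct-v0.3 model (assumed setup)
API_URL = "https://api-inference.huggingface.co/models/mistralai/Mistral-7B-Instruct-v0.3"
HEADERS = {"Authorization": "Bearer YOUR_HF_API_KEY"}

# Defining a function to generate a response from the LLM
def generate_response(prompt):
    payload = {
        "inputs": prompt,
        "parameters": {"max_new_tokens": 200, "do_sample": True}
    }
    result = requests.post(API_URL, headers=HEADERS, json=payload)
    return result.json()

# Building a context string from the knowledge retrieved from the graph
context = "\n".join(
    f"{r['entity']} {r['relation']} {r['related_entity']}"
    for r in relevant_data
    if r["related_entity"] is not None
)

# Combining the user's question (hypothetical example) with the graph context
user_question = "What is AI?"
prompt = f"Context:\n{context}\n\nQuestion: {user_question}\nAnswer:"

# Generating a response from the model
response = generate_response(prompt)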
# Converting the response into a proper format
response = str(response[0])
answer_start = response.find("Answer:") + len("Answer:")
cleaned_response = response[answer_start:].strip()
print("AI Response:", cleaned_response)
GraphRAG improves AI systems by storing knowledge and its relationships as a graph, which strengthens their logical understanding. This guide covered the benefits, applications, and implementation of GraphRAG, an approach well suited to building smarter, more context-aware chatbots.