1.概述
https://github.com/gusye1234/nano-graphrag
😭 GraphRAG很强大,但官方的实现阅读或修改起来非常困难。
😊 本项目提供了一个更小、更快、更简洁的 GraphRAG,同时保留了核心功能。
以下是该项目的详细代码注释,作为学习记录和后续修改代码的参考。
2.分模块注释以及分析
为了节省时间只介绍核心模块,即/nano_graphrag中的代码。
2.1 prompt.py
GRAPH_FIELD_SEP = "<SEP>"
PROMPTS = {}
PROMPTS["claim_extraction"],PROMPTS["community_report"],PROMPTS["entity_extraction"],PROMPTS["summarize_entity_descriptions"],
PROMPTS["entiti_continue_extraction"],PROMPTS["entiti_if_loop_extraction"],
PROMPTS["DEFAULT_ENTITY_TYPES"],PROMPTS["DEFAULT_TUPLE_DELIMITER"],PROMPTS["DEFAULT_RECORD_DELIMITER"],PROMPTS["DEFAULT_COMPLETION_DELIMITER"]
PROMPTS["local_rag_response"],PROMPTS["global_reduce_rag_response"],
PROMPTS["fail_response"],PROMPTS["process_tickers"]
使用的prompt是与官方范例相同的内容,这里就不再赘述了。
2.2 _llm.py
函数列表
gpt_4o_complete,
gpt_4o_mini_complete,
openai_embedding,
azure_gpt_4o_complete,
azure_openai_embedding,
azure_gpt_4o_mini_complete
看着很复杂其实相当于只有两个函数,openai_embedding是调用embedding模型。
python">@retry(stop=stop_after_attempt(5),wait=wait_exponential(multiplier=1, min=4, max=10),retry=retry_if_exception_type((RateLimitError, APIConnectionError)),
)
async def openai_complete_if_cache(model, prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:"""使用OpenAI的API完成文本生成,如果缓存中已存在相同请求则返回缓存结果。参数:- model: 使用的OpenAI模型名称。- prompt: 用户的输入提示文本。- system_prompt: (可选)系统提示信息,用以引导模型的响应风格或内容。- history_messages: (可选)历史对话消息列表,用于聊天上下文。- **kwargs: 其他额外参数,如缓存存储对象。返回:- str: API响应中模型生成的文本内容。"""openai_async_client = AsyncOpenAI()hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)messages = []if system_prompt:messages.append({"role": "system", "content": system_prompt})messages.extend(history_messages)messages.append({"role": "user", "content": prompt})# 如果缓存对象存在,计算当前请求的哈希值,尝试从缓存中获取结果if hashing_kv is not None:args_hash = compute_args_hash(model, messages)if_cache_return = await hashing_kv.get_by_id(args_hash)if if_cache_return is not None:return if_cache_return["return"]response = await openai_async_client.chat.completions.create(model=model, messages=messages, **kwargs)# 如果有缓存对象,将响应结果存入缓存if hashing_kv is not None:await hashing_kv.upsert({args_hash: {"return": response.choices[0].message.content, "model": model}})await hashing_kv.index_done_callback()return response.choices[0].message.contentasync def gpt_4o_complete(prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:return await openai_complete_if_cache("gpt-4o",prompt,system_prompt=system_prompt,history_messages=history_messages,**kwargs,)
其他四个都源于openai_complete_if_cache,顾名思义,调用LLM回答问题,如果有缓存则优先调用缓存中的结果,避免资源浪费。
python">@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
@retry(stop=stop_after_attempt(5),wait=wait_exponential(multiplier=1, min=4, max=10),retry=retry_if_exception_type((RateLimitError, APIConnectionError)),
)
async def openai_embedding(texts: list[str]) -> np.ndarray:openai_async_client = AsyncOpenAI()response = await openai_async_client.embeddings.create(model="text-embedding-3-small", input=texts, encoding_format="float")return np.array([dp.embedding for dp in response.data])
记录缓存使用的格式是BaseKVStorage在base里定义。
2.3_utils.py
函数列表
logger
日志记录器,几乎每个模块都有调用
python">logger = logging.getLogger("nano-graphrag")
convert_response_to_json
将响应字符串转换为JSON格式的数据。通过系统变量convert_response_to_json_func: callable = convert_response_to_json来调用。
python"># 通过正则化匹配,从字符串中提取出JSON字符串
def locate_json_string_body_from_string(content: str) -> Union[str, None]:"""Locate the JSON string body from a string"""maybe_json_str = re.search(r"{.*}", content, re.DOTALL)if maybe_json_str is not None:return maybe_json_str.group(0)else:return None# 将响应字符串转换为JSON格式的数据。
def convert_response_to_json(response: str) -> dict:json_str = locate_json_string_body_from_string(response)assert json_str is not None, f"Unable to parse JSON from response: {response}"try:data = json.loads(json_str)return dataexcept json.JSONDecodeError as e:logger.error(f"Failed to parse JSON: {json_str}")raise e from None
truncate_list_by_token_size
根据token大小截断列表数据。
python">def truncate_list_by_token_size(list_data: list, key: callable, max_token_size: int):"""Truncate a list of data by token size""""""根据token大小截断列表数据。该函数的目的是确保列表中数据的总token数不超过指定的最大token大小。当数据的总token数超过最大允许大小时,函数将返回截断后的列表。参数:- list_data: list, 需要截断的列表,其中每个元素为一个数据项。- key: callable, 用于从列表数据项中提取用于计算token大小的字符串的函数。- max_token_size: int, 允许的最大token大小,用于决定列表数据的截断点。返回:- 截断后的列表。如果max_token_size小于等于0,返回空列表。注意:- 该函数使用tiktoken对字符串进行编码并计算token数量,请确保在使用前已安装tiktoken库。- 截断操作基于累计token数量首次超过max_token_size发生的索引位置。"""if max_token_size <= 0:return []tokens = 0for i, data in enumerate(list_data):tokens += len(encode_string_by_tiktoken(key(data)))if tokens > max_token_size:return list_data[:i]return list_data
compute_mdhash_id,compute_args_hash
计算哈希值的辅助函数,第二个在_llm.py中已经用过了。
write_json,load_json
用来读写json对象的辅助函数
pack_user_ass_to_openai_messages
将用户和助手的对话打包为OpenAI消息格式。
接受一系列字符串参数,成对地将它们包装成交替的用户和助手角色的消息。
这对于将对话历史记录转换为可供OpenAI的API处理的格式特别有用。在_op.py里就是调用给history的。
python">def pack_user_ass_to_openai_messages(*args: str):"""将用户和助手的对话打包为OpenAI消息格式。该函数接受一系列字符串参数,成对地将它们包装成交替的用户和助手角色的消息。这对于将对话历史记录转换为可供OpenAI的API处理的格式特别有用。在_op.py里就是调用给history的。参数:*args (str): 一个或多个字符串参数,表示用户和助手之间的对话交替发言。返回:list: 一个字典列表,每个字典包含两个键值对:- 'role': 表示消息发送者的角色,根据参数序列中的位置交替为'user'或'assistant'。- 'content': 发送者发送的消息内容,来自输入参数序列中的对应位置。"""roles = ["user", "assistant"]return [{"role": roles[i % 2], "content": content} for i, content in enumerate(args)]
is_float_regex
判断是否是浮点数
split_string_by_multi_markers
通过多个标记(markers)分割字符串
list_of_list_to_csv
将多维列表转换为CSV格式,用在社区结构部分。具体是_pack_single_community_describe函数。
clean_str
清理字符串,在_op.py的_handle_single_entity_extraction函数中用到。
class EmbeddingFunc
定义一个用于嵌入的函数类。
limit_async_func_call
为异步函数添加最大并发调用次数的限制。
python">def limit_async_func_call(max_size: int, waitting_time: float = 0.0001):"""Add restriction of maximum async calling times for a async func""""""为异步函数添加最大并发调用次数的限制。参数:- max_size: int, 允许的最大并发调用次数。- waitting_time: float, 当达到最大并发调用次数时,每次检查间隔的时间(秒),默认为0.0001秒。返回:- 返回一个装饰器函数,用于包装需要限制并发调用的异步函数。"""def final_decro(func):"""Not using async.Semaphore to aovid use nest-asyncio"""__current_size = 0@wraps(func)async def wait_func(*args, **kwargs):nonlocal __current_size# 如果当前调用数量达到最大值,等待一段时间后再检查while __current_size >= max_size:await asyncio.sleep(waitting_time)__current_size += 1result = await func(*args, **kwargs)__current_size -= 1return resultreturn wait_funcreturn final_decro
wrap_embedding_func_with_attrs
用属性包装一个函数。
该函数返回一个装饰器,可用于为一个函数添加额外的属性参数。这些属性
被用于在函数执行时提供或记录额外的信息,比如函数的来源、类型等。
这个在_llm.py中调用,是给embedding函数添加了embedding_dim=1536, max_token_size=8192属性
2.4 base.py
QueryParam
查询参数类,用于定义查询时的各种配置选项。
python">class QueryParam:mode: Literal["local", "global", "naive"] = "global"only_need_context: bool = Falseresponse_type: str = "Multiple Paragraphs"level: int = 2top_k: int = 20# naive searchnaive_max_token_for_text_unit = 12000# local searchlocal_max_token_for_text_unit: int = 4000 # 12000 * 0.33local_max_token_for_local_context: int = 4800 # 12000 * 0.4local_max_token_for_community_report: int = 3200 # 12000 * 0.27local_community_single_one: bool = False# global searchglobal_min_community_rating: float = 0global_max_consider_community: float = 512global_max_token_for_community_report: int = 16384global_special_community_map_llm_kwargs: dict = field(default_factory=lambda: {"response_format": {"type": "json_object"}})
TextChunkSchema,SingleCommunitySchema
文本块,社区存储结构
class CommunitySchema
添加了字符串格式以及json格式报告的社区class,在_storage.py中有更详细的定义
class StorageNameSpace
存储命名空间类,用于管理存储操作。有namespace和global_config两个属性。是后面几个类的基础父类。
class BaseVectorStorage
基础向量存储类,继承自 StorageNameSpace。补充了embedding_func和meta_fields属性。用来定义各种向量存储,比如entities_vdb: BaseVectorStorage。
方法:
query: 查询方法,具体函数在_storage.py中,下同。
upsert: 插入或更新方法。
class BaseKVStorage
基础键值存储类,继承自 StorageNameSpace。用来定义各种键值存储,比如
community_reports: BaseKVStorage[CommunitySchema],
text_chunks_db: BaseKVStorage[TextChunkSchema]。
方法:
all_keys: 获取所有键。
get_by_id: 根据 ID 获取数据。
get_by_ids: 根据多个 ID 获取数据。
filter_keys: 筛选出不存在的键。
upsert: 插入或更新数据。
drop: 删除整个存储空间。
class BaseGraphStorage
基础图存储类,继承自 StorageNameSpace。用来定义各种图结构的存储,比知识图谱knwoledge_graph_inst: BaseGraphStorage
方法:
has_node: 判断节点是否存在。
has_edge: 判断边是否存在。
node_degree: 获取节点度数。
edge_degree: 获取边的度数。
get_node: 获取节点信息。
get_edge: 获取边信息。
get_node_edges: 获取节点的所有边。
upsert_node: 插入或更新节点。
upsert_edge: 插入或更新边。
clustering: 进行图聚类。
community_schema: 获取社区结构。
embed_nodes: 对节点进行嵌入。
2.5 _storage.py
里面存储了大量数据操作的函数实现,并且定义了知识图谱存储方式的类别,所以代码很长。
class JsonKVStorage
继承自BaseKVStorage,基本相同,并且包含方法的具体实现代码。
python">class JsonKVStorage(BaseKVStorage):def __post_init__(self):# 初始化时,根据全局配置确定工作目录,以便获取文件完整路径working_dir = self.global_config["working_dir"]# 根据命名空间生成特定的 JSON 文件名,用于存储键值数据self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json")# 加载存储的数据,如果文件不存在或为空,则初始化为空字典self._data = load_json(self._file_name) or {}# 打印日志,显示加载的数据条数logger.info(f"Load KV {self.namespace} with {len(self._data)} data")# 获取所有的键列表async def all_keys(self) -> list[str]:return list(self._data.keys())# 索引操作完成后,将当前数据写入 JSON 文件async def index_done_callback(self):write_json(self._data, self._file_name)# 通过 ID 获取数据async def get_by_id(self, id):return self._data.get(id, None)async def get_by_ids(self, ids, fields=None):"""根据ID列表获取数据项。参数:ids (list): 需要获取数据的ID列表。fields (list, 可选): 限制返回数据中的字段。如果未提供,默认为None,将返回完整数据项。返回:list: 包含按指定ID列表顺序排列的数据项的列表。如果某些ID未找到数据项,则相应位置为None。"""if fields is None:return [self._data.get(id, None) for id in ids]return [(# 如果数据项存在,并且ID在_data字典中,则构建一个仅包含fields中字段的新字典{k: v for k, v in self._data[id].items() if k in fields}# 检查_id是否在数据集中,以避免KeyErrorif self._data.get(id, None)else None)for id in ids]# 过滤出不在数据存储中的键列表async def filter_keys(self, data: list[str]) -> set[str]:return set([s for s in data if s not in self._data])# 插入或更新数据async def upsert(self, data: dict[str, dict]):self._data.update(data)# 清空当前存储数据async def drop(self):self._data = {}
class NanoVectorDBStorage
继承自BaseVectorDBStorage,还是用了类别NanoVectorDB。
python">class NanoVectorDBStorage(BaseVectorStorage):# 余弦相似度阈值,决定返回的结果质量cosine_better_than_threshold: float = 0.2def __post_init__(self):# 初始化向量数据库存储文件和嵌入配置self._client_file_name = os.path.join(self.global_config["working_dir"], f"vdb_{self.namespace}.json")self._max_batch_size = self.global_config["embedding_batch_num"]# 初始化向量数据库客户端(NanoVectorDB),并设置嵌入维度self._client = NanoVectorDB(self.embedding_func.embedding_dim, storage_file=self._client_file_name)# 从全局配置中获取查询的相似度阈值,或使用默认值self.cosine_better_than_threshold = self.global_config.get("query_better_than_threshold", self.cosine_better_than_threshold)async def upsert(self, data: dict[str, dict]):"""插入或更新向量数据。该方法用于将字典形式的数据插入或更新到向量数据库中。数据首先被转换成适合插入的格式,然后分批处理,以避免一次性插入过多数据导致的性能问题。之后,使用异步方式计算各批次数据的嵌入向量,并将这些向量附加到数据条目中,最后调用客户端的插入或更新方法完成操作。参数:data: dict[str, dict] - 一个字典,键是数据的唯一标识,值是包含实际数据内容的字典。返回:插入或更新操作的结果。"""logger.info(f"Inserting {len(data)} vectors to {self.namespace}")if not len(data):logger.warning("You insert an empty data to vector DB")return []# 将数据转换为适合插入的列表,并提取内容list_data = [{"__id__": k,**{k1: v1 for k1, v1 in v.items() if k1 in self.meta_fields},}for k, v in data.items()]contents = [v["content"] for v in data.values()]# 将数据按批次进行处理batches = [contents[i : i + self._max_batch_size]for i in range(0, len(contents), self._max_batch_size)]# 异步计算各批次数据的嵌入向量embeddings_list = await asyncio.gather(*[self.embedding_func(batch) for batch in batches])# 将所有批次的嵌入向量合并为一个大数组embeddings = np.concatenate(embeddings_list)# 将计算得到的嵌入向量附加到每个数据条目中for i, d in enumerate(list_data):d["__vector__"] = embeddings[i]# 调用客户端的插入或更新方法,完成数据的插入或更新results = self._client.upsert(datas=list_data)return resultsasync def query(self, query: str, top_k=5):"""根据提供的查询字符串获取最相关的文档。此异步方法使用预训练的embedding函数将查询转换为嵌入表示,然后在嵌入索引中搜索与查询最相似的文档。参数:- query: str,用户查询的字符串。- top_k: int,返回最相关的文档数量,默认为5。返回:- 一个列表,包含最相关的文档及其与查询的相似度距离。"""embedding = await self.embedding_func([query])embedding = embedding[0]results = self._client.query(query=embedding,top_k=top_k,better_than_threshold=self.cosine_better_than_threshold,)# 整理结果,添加文档id和距离信息results = [{**dp, "id": dp["__id__"], "distance": dp["__metrics__"]} for dp in results]return resultsasync def index_done_callback(self):self._client.save()
HNSWVectorStorage(另一种向量存储方式,这里先略过)
class NetworkXStorage
继承自BaseGraphStorage,存储图结构的数据。
可以先参考这个链接学习一些图(graph)知识。
图的一些基本知识(连通图、连通分量、最小生成树等知识的基本介绍)-CSDN博客
这里社区发现算法还是调用的外部库,和微软开源的GraphRAG一致。
python">class NetworkXStorage(BaseGraphStorage):# 加载并返回一个NetworkX图,存储格式为graphml。@staticmethoddef load_nx_graph(file_name) -> nx.Graph:if os.path.exists(file_name):return nx.read_graphml(file_name)return None# 将NetworkX图写入graphml文件。@staticmethoddef write_nx_graph(graph: nx.Graph, file_name):logger.info(f"Writing graph with {graph.number_of_nodes()} nodes, {graph.number_of_edges()} edges")nx.write_graphml(graph, file_name)@staticmethoddef stable_largest_connected_component(graph: nx.Graph) -> nx.Graph:"""Refer to https://github.com/microsoft/graphrag/index/graph/utils/stable_lcc.pyReturn the largest connected component of the graph, with nodes and edges sorted in a stable way.""""""返回图的最大连通分量,并以稳定的方式排序节点和边。参数:graph (nx.Graph): 输入的 NetworkX 图。返回:nx.Graph: 输入图的最大连通分量,以稳定方式排序。"""from graspologic.utils import largest_connected_componentgraph = graph.copy()graph = cast(nx.Graph, largest_connected_component(graph))node_mapping = {node: html.unescape(node.upper().strip()) for node in graph.nodes()} # type: ignoregraph = nx.relabel_nodes(graph, node_mapping)return NetworkXStorage._stabilize_graph(graph)@staticmethoddef _stabilize_graph(graph: nx.Graph) -> nx.Graph:"""Refer to https://github.com/microsoft/graphrag/index/graph/utils/stable_lcc.pyEnsure an undirected graph with the same relationships will always be read the same way.""""""确保无向图以相同的关系读取时始终相同。参数:graph (nx.Graph): 输入的网络图。返回:nx.Graph: 经过稳定处理的网络图。"""# 根据输入图的类型初始化一个新的图实例 fixed_graph = nx.DiGraph() if graph.is_directed() else nx.Graph()# 对节点进行排序,以确保节点的添加顺序一致sorted_nodes = graph.nodes(data=True)sorted_nodes = sorted(sorted_nodes, key=lambda x: x[0])# 向新图中添加排序后的节点fixed_graph.add_nodes_from(sorted_nodes)# 将边数据存储到列表中以便后续处理edges = list(graph.edges(data=True))# 如果图不是有向图,则对边进行排序,以确保边的顺序一致if not graph.is_directed():def _sort_source_target(edge):source, target, edge_data = edgeif source > target:temp = sourcesource = targettarget = tempreturn source, target, edge_dataedges = [_sort_source_target(edge) for edge in edges]# 定义获取边的键的函数,用于后续边的排序def _get_edge_key(source: Any, target: Any) -> str:return f"{source} -> {target}"# 对边进行排序edges = sorted(edges, key=lambda x: _get_edge_key(x[0], x[1]))# 向新图中添加排序后的边fixed_graph.add_edges_from(edges)return fixed_graphdef __post_init__(self):"""初始化函数,用于加载图数据并初始化相关属性。该函数首先根据全局配置中的工作目录和实例的命名空间来确定graphml文件的路径。然后尝试从该路径加载已存在的图数据。如果图数据存在,则使用NetworkXStorage加载,并记录日志信息包括图的节点数和边数。如果图数据不存在,则初始化一个新的无向图。最后,初始化两个算法字典,分别用于图的聚类算法和节点嵌入算法。"""self._graphml_xml_file = os.path.join(self.global_config["working_dir"], f"graph_{self.namespace}.graphml")preloaded_graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file)if preloaded_graph is not None:logger.info(f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges")self._graph = preloaded_graph or nx.Graph()self._clustering_algorithms = {"leiden": self._leiden_clustering,}self._node_embed_algorithms = {"node2vec": self._node2vec_embed,}# 将当前存储的图写入到GraphML文件中async def index_done_callback(self):NetworkXStorage.write_nx_graph(self._graph, self._graphml_xml_file)async def has_node(self, node_id: str) -> bool:"""异步检查图中是否存在指定节点。下面几个函数类似,就不做标注了,都是调用NetWorkX的函数。该方法主要用于确定图结构中是否包含特定的节点。它通过调用底层图对象的has_node方法,以高效的方式查询节点是否存在。参数:node_id (str): 要检查的节点的唯一标识符。返回:bool: 如果图中存在该节点,则返回True,否则返回False。"""return self._graph.has_node(node_id)async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:return self._graph.has_edge(source_node_id, target_node_id)async def get_node(self, node_id: str) -> Union[dict, None]:return self._graph.nodes.get(node_id)# 获取指定节点的度数。async def node_degree(self, node_id: str) -> int:# [numberchiffre]: node_id not part of graph returns `DegreeView({})` instead of 0return self._graph.degree(node_id) if self._graph.has_node(node_id) else 0# 计算两个节点的度数之和async def edge_degree(self, src_id: str, tgt_id: str) -> int:return (self._graph.degree(src_id) if self._graph.has_node(src_id) else 0) + (self._graph.degree(tgt_id) if self._graph.has_node(tgt_id) else 0)async def get_edge(self, source_node_id: str, target_node_id: str) -> Union[dict, None]:return self._graph.edges.get((source_node_id, target_node_id))async def get_node_edges(self, source_node_id: str):if self._graph.has_node(source_node_id):return list(self._graph.edges(source_node_id))return Noneasync def upsert_node(self, node_id: str, node_data: dict[str, str]):self._graph.add_node(node_id, **node_data)async def upsert_edge(self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]):self._graph.add_edge(source_node_id, target_node_id, **edge_data)# 根据指定的算法执行聚类操作。async def clustering(self, algorithm: str):if algorithm not in self._clustering_algorithms:raise ValueError(f"Clustering algorithm {algorithm} not supported")await self._clustering_algorithms[algorithm]()async def community_schema(self) -> dict[str, SingleCommunitySchema]:"""生成社区架构的异步方法。该方法计算并组织图中所有节点的社区结构信息。它通过分析节点的集群信息,构建不同层级的社区,并计算社区的各种属性,如层次、节点数、边数等。Returns:一个字典,键为社区键,值为SingleCommunitySchema对象,包含每个社区的详细信息。"""# 初始化结果字典,为每个社区准备默认属性results = defaultdict(lambda: dict(level=None,title=None,edges=set(),nodes=set(),chunk_ids=set(),occurrence=0.0,sub_communities=[],))max_num_ids = 0# 用于存储不同层级的社区levels = defaultdict(set)# 遍历图中的所有节点,收集社区信息for node_id, node_data in self._graph.nodes(data=True):# 如果节点没有集群信息,则跳过if "clusters" not in node_data:continueclusters = json.loads(node_data["clusters"])this_node_edges = self._graph.edges(node_id)# 遍历节点的所有集群信息for cluster in clusters:# 提取并更新社区的层级信息level = cluster["level"]cluster_key = str(cluster["cluster"])levels[level].add(cluster_key)results[cluster_key]["level"] = levelresults[cluster_key]["title"] = f"Cluster {cluster_key}"results[cluster_key]["nodes"].add(node_id)results[cluster_key]["edges"].update([tuple(sorted(e)) for e in this_node_edges])results[cluster_key]["chunk_ids"].update(node_data["source_id"].split(GRAPH_FIELD_SEP))# 计算最大的chunk_ids数量,用于后续计算出现率max_num_ids = max(max_num_ids, len(results[cluster_key]["chunk_ids"]))# 对层级进行排序,以便后续处理ordered_levels = sorted(levels.keys())# 构建子社区关系for i, curr_level in enumerate(ordered_levels[:-1]):next_level = ordered_levels[i + 1]this_level_comms = levels[curr_level]next_level_comms = levels[next_level]# compute the sub-communities by nodes intersection# 通过节点交集计算子社区关系for comm in this_level_comms:results[comm]["sub_communities"] = [cfor c in next_level_commsif results[c]["nodes"].issubset(results[comm]["nodes"])]# 处理并标准化结果字典中的数据for k, v in results.items():v["edges"] = list(v["edges"])v["edges"] = [list(e) for e in v["edges"]]v["nodes"] = list(v["nodes"])v["chunk_ids"] = list(v["chunk_ids"])v["occurrence"] = len(v["chunk_ids"]) / max_num_idsreturn dict(results)def _cluster_data_to_subgraphs(self, cluster_data: dict[str, list[dict[str, str]]]):"""将聚类数据分配到子图该方法将给定的聚类数据分配到图中的相应节点。每个节点的聚类信息以JSON格式存储在节点的'clusters'属性中。参数:cluster_data: 字典,包含节点ID和对应的聚类列表。每个聚类是一个字典,包含节点在不同聚类中的信息。"""# 遍历聚类数据字典中的每个节点及其聚类信息for node_id, clusters in cluster_data.items():# 将节点的聚类信息转换为JSON格式,并存储在图节点的'clusters'属性中self._graph.nodes[node_id]["clusters"] = json.dumps(clusters)async def _leiden_clustering(self):"""基于Leiden算法的图聚类异步方法。本方法使用Leiden算法对图进行聚类,以发现图中的社区结构。它从当前图的稳定最大连通组件开始,根据全局配置中的参数进行聚类,并将聚类结果转换为子图。"""from graspologic.partition import hierarchical_leiden# 获取图的稳定最大连通组件graph = NetworkXStorage.stable_largest_connected_component(self._graph)# 应用Leiden算法进行层次聚类community_mapping = hierarchical_leiden(graph,max_cluster_size=self.global_config["max_graph_cluster_size"],random_seed=self.global_config["graph_cluster_seed"],)# 准备节点社区字典node_communities: dict[str, list[dict[str, str]]] = defaultdict(list)# 准备用于跟踪各级别社区数量的字典__levels = defaultdict(set)# 遍历社区映射,构建节点社区字典和级别跟踪字典for partition in community_mapping:level_key = partition.levelcluster_id = partition.clusternode_communities[partition.node].append({"level": level_key, "cluster": cluster_id})__levels[level_key].add(cluster_id)# 转换节点社区字典和级别跟踪字典为最终形式node_communities = dict(node_communities)__levels = {k: len(v) for k, v in __levels.items()}# 记录每级社区的数量信息logger.info(f"Each level has communities: {dict(__levels)}")# 将聚类数据转换为子图self._cluster_data_to_subgraphs(node_communities)# 根据指定的算法进行节点嵌入。async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:if algorithm not in self._node_embed_algorithms:raise ValueError(f"Node embedding algorithm {algorithm} not supported")return await self._node_embed_algorithms[algorithm]()async def _node2vec_embed(self):"""异步方法,用于通过node2vec算法嵌入图结构数据。该方法使用graspologic库的node2vec_embed函数,根据内部图结构和配置参数进行图嵌入。它首先调用嵌入函数,然后提取嵌入结果中节点的ID,并返回嵌入向量和节点ID列表。"""from graspologic import embedembeddings, nodes = embed.node2vec_embed(self._graph,**self.global_config["node2vec_params"],)nodes_ids = [self._graph.nodes[node_id]["id"] for node_id in nodes]return embeddings, nodes_ids
2.6 _op.py
一看就知道,最重要的几个函数基本都封装在这里,所以代码量也来到了作者所说的800行级别(实际是1300行)。
┓( ´∀` )┏
函数列表
chunking_by_token_size
用来进行分块。
python">def chunking_by_token_size(content: str, overlap_token_size=128, max_token_size=1024, tiktoken_model="gpt-4o"
):"""根据token大小对文本进行分块。该函数用于将给定的文本内容按照指定的token大小限制进行分块,同时保证相邻块之间有重叠。主要用于处理大文本,使其能够适应如OpenAI的GPT系列模型的输入限制。参数:- content: str, 待分块的文本内容。- overlap_token_size: int, 默认128. 相邻文本块之间的重叠token数。- max_token_size: int, 默认1024. 每个文本块的最大token数。- tiktoken_model: str, 默认"gpt-4o". 用于token化和去token化的tiktoken模型。返回:- List[Dict[str, Any]], 包含每个文本块的tokens数量、文本内容和块顺序索引的列表。"""# 使用指定的tiktoken模型对文本进行token化tokens = encode_string_by_tiktoken(content, model_name=tiktoken_model)# 初始化存储分块结果的列表results = []# 遍历tokens,根据max_token_size和overlap_token_size进行分块for index, start in enumerate(range(0, len(tokens), max_token_size - overlap_token_size)):# 根据当前分块的起始位置和最大token数限制,获取分块的tokenschunk_content = decode_tokens_by_tiktoken(tokens[start : start + max_token_size], model_name=tiktoken_model)# 将当前分块的tokens数量、文本内容和块顺序索引添加到结果列表中results.append({"tokens": min(max_token_size, len(tokens) - start),"content": chunk_content.strip(),"chunk_order_index": index,})return results
extract_entities
内嵌了实体合并的代码,调用了很多_op.py中的函数,这里就不逐个赘述了。
python">async def extract_entities(chunks: dict[str, TextChunkSchema],knwoledge_graph_inst: BaseGraphStorage,entity_vdb: BaseVectorStorage,global_config: dict,
) -> Union[BaseGraphStorage, None]:"""异步函数extract_entities从文本块中提取实体并更新知识图谱。参数:chunks (dict[str, TextChunkSchema]): 文本块字典,键为文本块标识,值为包含文本块内容的TextChunkSchema对象。knwoledge_graph_inst (BaseGraphStorage): 知识图谱实例,用于存储提取的实体和关系。entity_vdb (BaseVectorStorage): 实体向量数据库实例,用于存储实体的向量表示。global_config (dict): 全局配置字典,包含模型函数、最大迭代次数等参数。返回:Union[BaseGraphStorage, None]: 更新后的知识图谱实例,如果没有提取到任何实体,则返回None。"""# 从全局配置中提取模型函数和实体提取最大迭代次数use_llm_func: callable = global_config["best_model_func"]entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"]# 将文本块排序,以便按顺序处理ordered_chunks = list(chunks.items())# 准备实体提取的提示模板和上下文基础信息entity_extract_prompt = PROMPTS["entity_extraction"]context_base = dict(tuple_delimiter=PROMPTS["DEFAULT_TUPLE_DELIMITER"],record_delimiter=PROMPTS["DEFAULT_RECORD_DELIMITER"],completion_delimiter=PROMPTS["DEFAULT_COMPLETION_DELIMITER"],entity_types=",".join(PROMPTS["DEFAULT_ENTITY_TYPES"]),)continue_prompt = PROMPTS["entiti_continue_extraction"]if_loop_prompt = PROMPTS["entiti_if_loop_extraction"]# 初始化计数器already_processed = 0already_entities = 0already_relations = 0async def _process_single_content(chunk_key_dp: tuple[str, TextChunkSchema]):"""异步函数_process_single_content处理单个文本块的内容,提取实体并更新知识图谱。参数:chunk_key_dp (tuple[str, TextChunkSchema]): 文本块的键值对,包含文本块标识和内容。返回:dict: 包含从文本块中提取的可能的节点和边的字典。"""# 初始化非局部变量,用于跟踪处理进度和统计信息nonlocal already_processed, already_entities, already_relations# 解析文本块的键和数据chunk_key = chunk_key_dp[0]chunk_dp = chunk_key_dp[1]content = chunk_dp["content"]# 构建提示信息并调用模型函数提取实体hint_prompt = entity_extract_prompt.format(**context_base, input_text=content)final_result = await use_llm_func(hint_prompt)# 构建对话历史并进行多次迭代以提取更多实体history = pack_user_ass_to_openai_messages(hint_prompt, final_result)for now_glean_index in range(entity_extract_max_gleaning):glean_result = await use_llm_func(continue_prompt, history_messages=history)history += pack_user_ass_to_openai_messages(continue_prompt, glean_result)final_result += glean_result# 检查是否继续迭代if now_glean_index == entity_extract_max_gleaning - 1:breakif_loop_result: str = await use_llm_func(if_loop_prompt, history_messages=history)if_loop_result = if_loop_result.strip().strip('"').strip("'").lower()if if_loop_result != "yes":break# 解析结果并更新知识图谱records = split_string_by_multi_markers(final_result,[context_base["record_delimiter"], context_base["completion_delimiter"]],)maybe_nodes = defaultdict(list)maybe_edges = defaultdict(list)for record in records:record = re.search(r"\((.*)\)", record)if record is None:continuerecord = record.group(1)record_attributes = split_string_by_multi_markers(record, [context_base["tuple_delimiter"]])# 处理实体提取if_entities = await _handle_single_entity_extraction(record_attributes, chunk_key)if if_entities is not None:maybe_nodes[if_entities["entity_name"]].append(if_entities)continue# 处理关系提取if_relation = await _handle_single_relationship_extraction(record_attributes, chunk_key)if if_relation is not None:maybe_edges[(if_relation["src_id"], if_relation["tgt_id"])].append(if_relation)# 更新处理进度和统计信息already_processed += 1already_entities += len(maybe_nodes)already_relations += len(maybe_edges)now_ticks = PROMPTS["process_tickers"][already_processed % len(PROMPTS["process_tickers"])]print(f"{now_ticks} Processed {already_processed} chunks, {already_entities} entities(duplicated), {already_relations} relations(duplicated)\r",end="",flush=True,)return dict(maybe_nodes), dict(maybe_edges)# 并发处理所有文本块# use_llm_func is wrapped in ascynio.Semaphore, limiting max_async callingsresults = await asyncio.gather(*[_process_single_content(c) for c in ordered_chunks])print() # clear the progress barmaybe_nodes = defaultdict(list)maybe_edges = defaultdict(list)for m_nodes, m_edges in results:for k, v in m_nodes.items():maybe_nodes[k].extend(v)for k, v in m_edges.items():# it's undirected graphmaybe_edges[tuple(sorted(k))].extend(v)all_entities_data = await asyncio.gather(*[_merge_nodes_then_upsert(k, v, knwoledge_graph_inst, global_config)for k, v in maybe_nodes.items()])await asyncio.gather(*[_merge_edges_then_upsert(k[0], k[1], v, knwoledge_graph_inst, global_config)for k, v in maybe_edges.items()])if not len(all_entities_data):logger.warning("Didn't extract any entities, maybe your LLM is not working")return Noneif entity_vdb is not None:data_for_vdb = {compute_mdhash_id(dp["entity_name"], prefix="ent-"): {"content": dp["entity_name"] + dp["description"],"entity_name": dp["entity_name"],}for dp in all_entities_data}await entity_vdb.upsert(data_for_vdb)return knwoledge_graph_inst
generate_community_report
生成社区报告,同样调用了很多其他函数。
python">async def generate_community_report(community_report_kv: BaseKVStorage[CommunitySchema],knwoledge_graph_inst: BaseGraphStorage,global_config: dict,
):"""异步生成社区报告函数该函数用于异步生成社区报告,根据提供的知识图谱实例和全局配置,从社区模式中提取数据,并利用语言模型生成社区报告。参数:- community_report_kv: 实现了社区模式存储的BaseKVStorage实例。- knwoledge_graph_inst: 实现了知识图谱存储的BaseGraphStorage实例。- global_config: 包含生成社区报告所需的各种配置项的字典。返回值:无返回值,但会打印处理进度,并将生成的社区报告存储在community_report_kv中。"""# 从全局配置中提取语言模型的额外参数、使用的模型函数和字符串到JSON的转换函数llm_extra_kwargs = global_config["special_community_report_llm_kwargs"]use_llm_func: callable = global_config["best_model_func"]use_string_json_convert_func: callable = global_config["convert_response_to_json_func"]# 加载社区报告的提示模板community_report_prompt = PROMPTS["community_report"]# 从知识图谱实例中获取所有社区模式,并初始化已处理社区计数器communities_schema = await knwoledge_graph_inst.community_schema()community_keys, community_values = list(communities_schema.keys()), list(communities_schema.values())already_processed = 0# 定义异步函数_form_single_community_report,用于生成单个社区的报告async def _form_single_community_report(community: SingleCommunitySchema, already_reports: dict[str, CommunitySchema]):nonlocal already_processed# 为当前社区生成描述文本describe = await _pack_single_community_describe(knwoledge_graph_inst,community,max_token_size=global_config["best_model_max_token_size"],already_reports=already_reports,global_config=global_config,)# 构建完整的prompt并使用语言模型生成响应prompt = community_report_prompt.format(input_text=describe)response = await use_llm_func(prompt, **llm_extra_kwargs)# 将响应转换为JSON格式,并更新已处理社区计数器和进度打印data = use_string_json_convert_func(response)already_processed += 1now_ticks = PROMPTS["process_tickers"][already_processed % len(PROMPTS["process_tickers"])]print(f"{now_ticks} Processed {already_processed} communities\r",end="",flush=True,)return data# 按社区级别排序,并从高到低处理levels = sorted(set([c["level"] for c in community_values]), reverse=True)logger.info(f"Generating by levels: {levels}")community_datas = {}for level in levels:# 筛选当前级别的社区,并并行生成报告this_level_community_keys, this_level_community_values = zip(*[(k, v)for k, v in zip(community_keys, community_values)if v["level"] == level])this_level_communities_reports = await asyncio.gather(*[_form_single_community_report(c, community_datas)for c in this_level_community_values])# 将生成的报告整合到社区数据字典中community_datas.update({k: {"report_string": _community_report_json_to_str(r),"report_json": r,**v,}for k, r, v in zip(this_level_community_keys,this_level_communities_reports,this_level_community_values,)})print() # clear the progress barawait community_report_kv.upsert(community_datas)
local_query
本地查询。
python">async def local_query(query,knowledge_graph_inst: BaseGraphStorage,entities_vdb: BaseVectorStorage,community_reports: BaseKVStorage[CommunitySchema],text_chunks_db: BaseKVStorage[TextChunkSchema],query_param: QueryParam,global_config: dict,
) -> str:"""根据查询和配置,执行本地查询并返回结果。该函数首先根据查询和一系列存储实例构建本地查询上下文。如果查询参数指示只需要上下文,则直接返回上下文字符串。如果上下文为空,则返回预定义的查询失败响应。否则,将根据上下文和查询参数构造系统提示,并使用全局配置中指定的最佳模型函数生成响应。参数:query: 查询字符串。knowledge_graph_inst: 知识图谱存储实例。entities_vdb: 实体向量存储实例。community_reports: 社区报告键值存储实例。text_chunks_db: 文本块键值存储实例。query_param: 查询参数,包含查询类型和响应类型等信息。global_config: 全局配置字典,包含模型函数等关键配置。返回:根据查询和上下文生成的响应字符串。"""# 从全局配置中获取最佳模型函数use_model_func = global_config["best_model_func"]# 构建本地查询上下文context = await _build_local_query_context(query,knowledge_graph_inst,entities_vdb,community_reports,text_chunks_db,query_param,)# 如果查询参数指示只需要上下文,则直接返回上下文if query_param.only_need_context:return context# 如果上下文为空,则返回查询失败的响应if context is None:return PROMPTS["fail_response"]# 根据预定义模板构造系统提示,包含上下文数据和查询参数中的响应类型sys_prompt_temp = PROMPTS["local_rag_response"]sys_prompt = sys_prompt_temp.format(context_data=context, response_type=query_param.response_type)# 使用最佳模型函数生成查询响应response = await use_model_func(query,system_prompt=sys_prompt,)return response
global_query
全局查询,并没有使用map-reduce的方法,而是直接把社区摘要进行了top-k排序。
python">async def global_query(query,knowledge_graph_inst: BaseGraphStorage,entities_vdb: BaseVectorStorage,community_reports: BaseKVStorage[CommunitySchema],text_chunks_db: BaseKVStorage[TextChunkSchema],query_param: QueryParam,global_config: dict,
) -> str:"""异步执行全局查询。此函数根据查询参数和配置从知识图谱、实体向量存储、社区报告及文本块数据库中获取相关信息,并对这些信息进行处理和排序,最终生成一个综合的回答。参数:- query: 查询字符串- knowledge_graph_inst: 知识图谱存储实例- entities_vdb: 实体向量存储实例- community_reports: 社区报告存储实例,存储类型为BaseKVStorage- text_chunks_db: 文本块存储实例,存储类型为BaseKVStorage- query_param: 查询参数对象,包含查询级别、最大考虑社区数等参数- global_config: 全局配置字典,包含最佳模型函数等关键配置返回:- str: 查询的最终回答字符串"""# 获取并筛选社区schemacommunity_schema = await knowledge_graph_inst.community_schema()community_schema = {k: v for k, v in community_schema.items() if v["level"] <= query_param.level}if not len(community_schema):return PROMPTS["fail_response"]use_model_func = global_config["best_model_func"]# 对社区schema进行排序sorted_community_schemas = sorted(community_schema.items(),key=lambda x: x[1]["occurrence"],reverse=True,)sorted_community_schemas = sorted_community_schemas[: query_param.global_max_consider_community]community_datas = await community_reports.get_by_ids([k[0] for k in sorted_community_schemas])community_datas = [c for c in community_datas if c is not None]community_datas = [cfor c in community_datasif c["report_json"].get("rating", 0) >= query_param.global_min_community_rating]community_datas = sorted(community_datas,key=lambda x: (x["occurrence"], x["report_json"].get("rating", 0)),reverse=True,)logger.info(f"Revtrieved {len(community_datas)} communities")# 映射社区数据并聚合支持点map_communities_points = await _map_global_communities(query, community_datas, query_param, global_config)final_support_points = []for i, mc in enumerate(map_communities_points):for point in mc:if "description" not in point:continuefinal_support_points.append({"analyst": i,"answer": point["description"],"score": point.get("score", 1),})final_support_points = [p for p in final_support_points if p["score"] > 0]if not len(final_support_points):return PROMPTS["fail_response"]final_support_points = sorted(final_support_points, key=lambda x: x["score"], reverse=True)final_support_points = truncate_list_by_token_size(final_support_points,key=lambda x: x["answer"],max_token_size=query_param.global_max_token_for_community_report,)points_context = []for dp in final_support_points:points_context.append(f"""----Analyst {dp['analyst']}----
Importance Score: {dp['score']}
{dp['answer']}
""")points_context = "\n".join(points_context)if query_param.only_need_context:return points_contextsys_prompt_temp = PROMPTS["global_reduce_rag_response"]response = await use_model_func(query,sys_prompt_temp.format(report_data=points_context, response_type=query_param.response_type),)return response
naive_query
朴素查询,不做重点。
2.7 graphrag.py
这回终于可以看主函数了,但有了前面的铺垫,具体代码就比较简单了。不过里面包含了不少增量式的处理方法,只要理解基本思想和功能看代码也会清楚一些。目标是可以实现输入a文档后,再输入b文档补充到a文档生成的知识图谱中,也就是图扩展。以后可以考虑在此基础上做图更新。
函数列表
always_get_an_event_loop()
python">def always_get_an_event_loop() -> asyncio.AbstractEventLoop:try:# If there is already an event loop, use it.loop = asyncio.get_event_loop()except RuntimeError:# If in a sub-thread, create a new event loop.logger.info("Creating a new event loop in a sub-thread.")loop = asyncio.new_event_loop()asyncio.set_event_loop(loop)return loop
class GraphRAG
主要函数类别,实现模块化调用。
insert()通过always_get_an_event_loop()调用异步函数ainsert(),以实现增量添加文本文件,query()和aquery()同理。
ainsert()——分块,实体关系提取,知识图谱聚类,社区报告的函数调用
aquery()——local,global,naive
python">class GraphRAG:# 表示工作目录的路径,默认是根据当前日期时间生成的目录working_dir: str = field(default_factory=lambda: f"./nano_graphrag_cache_{datetime.now().strftime('%Y-%m-%d-%H:%M:%S')}")# graph mode# 是否本地存储,默认是本地存储enable_local: bool = True# 是否启用朴素rag,默认不启用enable_naive_rag: bool = False# text chunkingchunk_func: Callable[[str, Optional[int], Optional[int], Optional[str]], List[Dict[str, Union[str, int]]]] = chunking_by_token_size# 分块大小chunk_token_size: int = 1200# 分块重叠数量chunk_overlap_token_size: int = 100# tiktoken使用的模型名字,默认为gpt-4otiktoken_model_name: str = "gpt-4o"# entity extraction# 实体提取最大“拾取”次数,也就是反复提取次数,默认不反复提取entity_extract_max_gleaning: int = 1# 实体摘要最大token数entity_summary_to_max_tokens: int = 500# graph clustering# 社区聚类算法,默认为莱顿算法graph_cluster_algorithm: str = "leiden"# 最大社区聚类节点数,默认为10max_graph_cluster_size: int = 10# 社区聚类随机种子,默认为0xDEADBEEFgraph_cluster_seed: int = 0xDEADBEEF# node embeddingnode_embedding_algorithm: str = "node2vec"# 如果没有显式传入 node2vec_params,则会调用这个 lambda 函数,自动生成并赋值为这个默认的字典node2vec_params: dict = field(default_factory=lambda: {"dimensions": 1536,"num_walks": 10,"walk_length": 40,"num_walks": 10,"window_size": 2,"iterations": 3,"random_seed": 3,})# community reports# 以json格式返回社区报告special_community_report_llm_kwargs: dict = field(default_factory=lambda: {"response_format": {"type": "json_object"}})# text embedding# 默认使用openai的embeddingembedding_func: EmbeddingFunc = field(default_factory=lambda: openai_embedding)# 批量大小,默认为32embedding_batch_num: int = 32# 最大并发请求数量embedding_func_max_async: int = 16# 用于比较查询结果质量的阈值query_better_than_threshold: float = 0.2# LLM相关调用# 需要两种类型的大语言模型(LLM),一种是高性能的,用于规划和回应;另一种是成本较低的,用于总结using_azure_openai: bool = Falsebest_model_func: callable = gpt_4o_completebest_model_max_token_size: int = 32768best_model_max_async: int = 16cheap_model_func: callable = gpt_4o_mini_completecheap_model_max_token_size: int = 32768cheap_model_max_async: int = 16# entity extraction# 实体提取函数entity_extraction_func: callable = extract_entities# storage# 存储类型设置# 键值存储,json,具体定义在_storage.pykey_string_value_json_storage_cls: Type[BaseKVStorage] = JsonKVStorage# 向量数据库存储,具体定义在_storage.pyvector_db_storage_cls: Type[BaseVectorStorage] = NanoVectorDBStorage# 为向量数据库存储类提供可选的参数字典vector_db_storage_cls_kwargs: dict = field(default_factory=dict)# 图数据库存储,默认为NetworkXStoragegraph_storage_cls: Type[BaseGraphStorage] = NetworkXStorage# 是否启用llm缓存enable_llm_cache: bool = True# extension# 用于传递额外参数的字典addon_params: dict = field(default_factory=dict)# 用于将 LLM 输出转换为 JSON 的函数convert_response_to_json_func: callable = convert_response_to_json# 在对象初始化后调用此方法,主要作用为打印配置信息和根据配置进行一些设置调整def __post_init__(self):# 将对象的属性以键值对的形式打印出来,用于调试和日志记录_print_config = ",\n ".join([f"{k} = {v}" for k, v in asdict(self).items()])logger.debug(f"GraphRAG init with param:\n\n {_print_config}\n")# 如果配置了使用Azure OpenAI服务,则调整相关函数为Azure版本if self.using_azure_openai:if self.best_model_func == gpt_4o_complete:self.best_model_func = azure_gpt_4o_completeif self.cheap_model_func == gpt_4o_mini_complete:self.cheap_model_func = azure_gpt_4o_mini_completeif self.embedding_func == openai_embedding:self.embedding_func = azure_openai_embeddinglogger.info("Switched the default openai funcs to Azure OpenAI if you didn't set any of it")# 确保工作目录存在,如果不存在则创建if not os.path.exists(self.working_dir):logger.info(f"Creating working directory {self.working_dir}")os.makedirs(self.working_dir)# 初始化存储类实例,用于存储完整文档self.full_docs = self.key_string_value_json_storage_cls(namespace="full_docs", global_config=asdict(self))# 初始化存储类实例,用于存储文本块self.text_chunks = self.key_string_value_json_storage_cls(namespace="text_chunks", global_config=asdict(self))# 根据配置初始化LLM响应缓存,如果配置为启用则创建缓存实例self.llm_response_cache = (self.key_string_value_json_storage_cls(namespace="llm_response_cache", global_config=asdict(self))if self.enable_llm_cacheelse None)# 初始化存储类实例,用于存储社区报告self.community_reports = self.key_string_value_json_storage_cls(namespace="community_reports", global_config=asdict(self))# 初始化图存储类实例,用于存储块实体关系图self.chunk_entity_relation_graph = self.graph_storage_cls(namespace="chunk_entity_relation", global_config=asdict(self))# 限制embedding函数的异步调用次数self.embedding_func = limit_async_func_call(self.embedding_func_max_async)(self.embedding_func)# 根据配置初始化向量数据库存储类实例,用于存储实体self.entities_vdb = (self.vector_db_storage_cls(namespace="entities",global_config=asdict(self),embedding_func=self.embedding_func,meta_fields={"entity_name"},)if self.enable_localelse None)# 根据配置初始化向量数据库存储类实例,用于存储块self.chunks_vdb = (self.vector_db_storage_cls(namespace="chunks",global_config=asdict(self),embedding_func=self.embedding_func,)if self.enable_naive_ragelse None)# 限制最佳模型函数的异步调用次数,并为其配置哈希键值存储self.best_model_func = limit_async_func_call(self.best_model_max_async)(partial(self.best_model_func, hashing_kv=self.llm_response_cache))# 限制廉价模型函数的异步调用次数,并为其配置哈希键值存储self.cheap_model_func = limit_async_func_call(self.cheap_model_max_async)(partial(self.cheap_model_func, hashing_kv=self.llm_response_cache))def insert(self, string_or_strings):loop = always_get_an_event_loop()return loop.run_until_complete(self.ainsert(string_or_strings))def query(self, query: str, param: QueryParam = QueryParam()):loop = always_get_an_event_loop()return loop.run_until_complete(self.aquery(query, param))async def aquery(self, query: str, param: QueryParam = QueryParam()):"""异步查询函数,根据指定的查询模式执行相应的查询。该函数支持三种查询模式:本地模式 ("local")、全局模式 ("global") 和朴素模式 ("naive")。根据提供的参数确定使用哪种查询模式,并在查询完成后执行查询结束的钩子函数。参数:- query: str, 要查询的字符串。- param: QueryParam, 查询参数对象,包含查询模式和其他查询相关的参数。返回:- response: 从查询函数返回的响应。异常:- ValueError: 当尝试在不支持的模式下执行查询时抛出。"""if param.mode == "local" and not self.enable_local:raise ValueError("enable_local is False, cannot query in local mode")if param.mode == "naive" and not self.enable_naive_rag:raise ValueError("enable_naive_rag is False, cannot query in local mode")# 根据查询模式执行相应的查询函数if param.mode == "local":response = await local_query(query,self.chunk_entity_relation_graph,self.entities_vdb,self.community_reports,self.text_chunks,param,asdict(self),)elif param.mode == "global":response = await global_query(query,self.chunk_entity_relation_graph,self.entities_vdb,self.community_reports,self.text_chunks,param,asdict(self),)elif param.mode == "naive":response = await naive_query(query,self.chunks_vdb,self.text_chunks,param,asdict(self),)else:raise ValueError(f"Unknown mode {param.mode}")await self._query_done()return responseasync def ainsert(self, string_or_strings):"""异步插入字符串或字符串列表到文档和片段数据库中,并更新知识图谱和社区报告。参数:string_or_strings (str 或 List[str]): 要插入的字符串或字符串列表。"""try:# 如果输入是一个字符串,将其转换为列表if isinstance(string_or_strings, str):string_or_strings = [string_or_strings]# ---------- new docs# 将字符串或字符串列表string_or_strings中的每个元素去除首尾空白后,作为文档内容。# 计算其MD5哈希值并添加前缀doc-作为键,内容本身作为值,生成一个新的字典new_docs。new_docs = {compute_mdhash_id(c.strip(), prefix="doc-"): {"content": c.strip()}for c in string_or_strings}# 筛选出需要添加的新文档ID_add_doc_keys = await self.full_docs.filter_keys(list(new_docs.keys()))# 根据筛选结果更新新文档字典new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys}# 如果没有新文档需要添加,记录日志并返回if not len(new_docs):logger.warning(f"All docs are already in the storage")return# 记录插入新文档的日志logger.info(f"[New Docs] inserting {len(new_docs)} docs")# ---------- chunkinginserting_chunks = {}for doc_key, doc in new_docs.items():# 为每个文档生成片段chunks = {compute_mdhash_id(dp["content"], prefix="chunk-"): {**dp,"full_doc_id": doc_key,}for dp in self.chunk_func(doc["content"],overlap_token_size=self.chunk_overlap_token_size,max_token_size=self.chunk_token_size,tiktoken_model=self.tiktoken_model_name,)}inserting_chunks.update(chunks)# 筛选出需要添加的新片段ID_add_chunk_keys = await self.text_chunks.filter_keys(list(inserting_chunks.keys()))# 根据筛选结果更新新片段字典inserting_chunks = {k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys}# 如果没有新片段需要添加,记录日志并返回if not len(inserting_chunks):logger.warning(f"All chunks are already in the storage")returnlogger.info(f"[New Chunks] inserting {len(inserting_chunks)} chunks")# 如果启用了简单的RAG,则插入片段到相应的数据库if self.enable_naive_rag:logger.info("Insert chunks for naive RAG")await self.chunks_vdb.upsert(inserting_chunks)# 由于目前不支持增量更新社区,因此删除所有现有的社区报告# TODO: no incremental update for communities now, so just drop allawait self.community_reports.drop()# ---------- extract/summary entity and upsert to graph# 提取新实体和关系,并更新到知识图谱中logger.info("[Entity Extraction]...")maybe_new_kg = await self.entity_extraction_func(inserting_chunks,knwoledge_graph_inst=self.chunk_entity_relation_graph,entity_vdb=self.entities_vdb,global_config=asdict(self),)if maybe_new_kg is None:logger.warning("No new entities found")returnself.chunk_entity_relation_graph = maybe_new_kg# ---------- update clusterings of graph# 更新知识图谱的聚类logger.info("[Community Report]...")await self.chunk_entity_relation_graph.clustering(self.graph_cluster_algorithm)# 生成并更新社区报告await generate_community_report(self.community_reports, self.chunk_entity_relation_graph, asdict(self))# ---------- commit upsertings and indexing# 提交所有更新和索引操作await self.full_docs.upsert(new_docs)await self.text_chunks.upsert(inserting_chunks)finally:await self._insert_done()async def _insert_done(self):tasks = []for storage_inst in [self.full_docs,self.text_chunks,self.llm_response_cache,self.community_reports,self.entities_vdb,self.chunks_vdb,self.chunk_entity_relation_graph,]:if storage_inst is None:continuetasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())await asyncio.gather(*tasks)async def _query_done(self):"""异步函数,用于在查询完成后,确保所有相关的存储实例的索引更新操作完成。1. 遍历预定义的存储实例列表(当前只有一个llm_response_cache)。2. 对于每个非空的存储实例,添加其索引更新完成的回调函数到任务列表。3. 使用asyncio.gather等待所有添加到任务列表中的回调函数执行完成。该函数的主要作用是确保在查询完成后,所有配置的存储实例都完成了它们的索引更新操作,这对于保持数据一致性和完整性至关重要。"""tasks = []for storage_inst in [self.llm_response_cache]:if storage_inst is None:continue# 将存储实例的索引更新完成的回调函数添加到任务列表tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())await asyncio.gather(*tasks)
3.使用测试
这里我采用本地embedding和调用kimi作为大模型来测试结果,根据官方例子编写了如下的代码。
python">import os
import sys
import logging
from openai import AsyncOpenAI
from nano_graphrag import GraphRAG, QueryParam
from nano_graphrag.base import BaseKVStorage
from nano_graphrag._utils import compute_args_hash, wrap_embedding_func_with_attrs
from sentence_transformers import SentenceTransformer
import numpy as np
sys.path.append("..")logging.basicConfig(level=logging.WARNING)
logging.getLogger("nano-graphrag").setLevel(logging.INFO)API_KEY = "填你自己的"
MODEL = "moonshot-v1-32k"
URL = "https://api.moonshot.cn/v1"
QUESTION = "请概括故事的主要情节并分析故事的主旨。"
WORKING_DIR = "./workspace"
FILEPATH = "./txt/追逐雪的人.txt"
EMBED_MODEL = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2", cache_folder=WORKING_DIR, device="cpu"
)async def model_if_cache(prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:openai_async_client = AsyncOpenAI(api_key=API_KEY, base_url=URL)messages = []if system_prompt:messages.append({"role": "system", "content": system_prompt})# Get the cached response if having-------------------hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)messages.extend(history_messages)messages.append({"role": "user", "content": prompt})if hashing_kv is not None:args_hash = compute_args_hash(MODEL, messages)if_cache_return = await hashing_kv.get_by_id(args_hash)if if_cache_return is not None:return if_cache_return["return"]# -----------------------------------------------------response = await openai_async_client.chat.completions.create(model=MODEL, messages=messages, **kwargs)# Cache the response if having-------------------if hashing_kv is not None:await hashing_kv.upsert({args_hash: {"return": response.choices[0].message.content, "model": MODEL}})# -----------------------------------------------------return response.choices[0].message.contentdef remove_if_exist(file):if os.path.exists(file):os.remove(file)def query():rag = GraphRAG(working_dir=WORKING_DIR,best_model_func=model_if_cache,cheap_model_func=model_if_cache,embedding_func=local_embedding)print(rag.query(QUESTION, param=QueryParam(mode="global")))@wrap_embedding_func_with_attrs(embedding_dim=EMBED_MODEL.get_sentence_embedding_dimension(),max_token_size=EMBED_MODEL.max_seq_length,
)async def local_embedding(texts: list[str]) -> np.ndarray:return EMBED_MODEL.encode(texts, normalize_embeddings=True)def insert():from time import timewith open(FILEPATH, encoding="utf-8-sig") as f:FAKE_TEXT = f.read()remove_if_exist(f"{WORKING_DIR}/vdb_entities.json")remove_if_exist(f"{WORKING_DIR}/kv_store_full_docs.json")remove_if_exist(f"{WORKING_DIR}/kv_store_text_chunks.json")remove_if_exist(f"{WORKING_DIR}/kv_store_community_reports.json")remove_if_exist(f"{WORKING_DIR}/graph_chunk_entity_relation.graphml")rag = GraphRAG(working_dir=WORKING_DIR,enable_llm_cache=True,best_model_func=model_if_cache,cheap_model_func=model_if_cache,embedding_func=local_embedding)start = time()rag.insert(FAKE_TEXT)print("indexing time:", time() - start)if __name__ == "__main__":insert()query()
回答结果如下
### 故事主要情节概括
故事主要围绕两个核心人物——"她"和"我"——以及他们的宠物"花卷"展开。这两个人物通过共同的活动、
深入的对话和相互关怀来加深彼此的情感联系。他们一起经历了各种活动和情感交流,包括参观博物馆、
讨论艺术作品,以及参与社区中的艺术和教育活动。宠物"花卷"在故事中扮演了情感纽带的角色,增强了
两位主角之间的联系,并且是他们共同关心和照顾的对象。
### 故事主旨分析
1. **人际关系与情感联系**:故事的主旨在于探索人与人之间的深层情感联系,以及在共同经历和相互
支持中成长和发现生活的意义。"她"和"我"之间的关系是故事的核心,他们通过亲密的旅程,强调了在面
对生活中的挑战和困难时,人与人之间的相互支持和关怀的重要性。
2. **社区互动与个人成长**:故事还展现了社区互动和个人成长的重要性。一个男孩在社区中的互动和
成长经历,涉及学校活动、艺术任务和个人关系,表明社区成员之间的互动对个人发展具有重要影响。艺
术和教育活动在促进社区凝聚力和个人发展中发挥了重要作用。
3. **艺术与美的欣赏**:故事中包含了对艺术和美的欣赏,两位主角在参观博物馆和讨论艺术作品时展
现了他们的情感智慧和共同的感性。艺术活动,特别是海报着色任务,是连接社区成员的重要活动,促进
社区参与和艺术表达。
4. **教育与权威角色**:老师在社区中具有双重角色,既是权威人物,又是教育活动的促进者,管理课
堂行为并支持艺术任务的完成。这表明教育者在塑造社区文化和促进个人成长中扮演着关键角色。
5. **导师与钦佩动态**:男孩与俱乐部领袖的关系表明社区内存在强烈的导师或钦佩动态,影响男孩在
社区中的动机和参与度。这种关系对个人成长和社区凝聚力具有积极影响。
综上所述,故事通过"她"、"我"和宠物"花卷"的亲密旅程,以及他们在社区中的互动和成长经历,探讨了
人际关系、社区互动、艺术欣赏、教育角色和导师动态等多重主题,展现了人与人之间的情感联系、个人
成长和社区凝聚力的重要性。
生成的文件均为json格式,知识图谱是graphml格式。
构建的知识图谱如下图。