找回密码
 立即注册
首页 业界区 安全 使用Langchain生成本地rag知识库并搭载大模型 ...

使用Langchain生成本地rag知识库并搭载大模型

轩辕娅童 2025-10-1 16:43:51
准备设备: 手机+aidlux2.0个人版

一、下载依赖

pip install langchain langchain-community faiss-cpu pypdf
二、安装ollama并下载模型
  1. curl -fsSL https://ollama.com/install.sh | sh #需要VPN
  2. ollama serve & #让ollama服务在后台运行
复制代码
安装完毕可以查看ollama版本进行验证,出现版本号之后就可以使用ollama
ollama -v
1.png

考虑性能因素,选择下载较小的模型
  1. ollama pull phi3:mini
  2. ollama pull all-minilm
复制代码
三、构建rag知识库


  • 打开手机上的aidlux应用,打开Cloud_ip查看网络ip,输入ip到浏览器+端口号:8000访问
    输入以下命令:
  1. cd ~
  2. touch build_knowledge_base.py
复制代码

  • 在文件浏览器中/home/aidlux 下找到对应py文件并打开
2.png


  • 自行准备一个知识库文本(txt或pdf),将文本的路径填入脚本中
3.png


  • 写入以下脚本内容
  1. from langchain_community.document_loaders import PyPDFLoader, TextLoader
  2. from langchain_text_splitters import RecursiveCharacterTextSplitter
  3. from langchain_community.embeddings import OllamaEmbeddings
  4. from langchain_community.vectorstores import FAISS
  5. import os
  6. # 1. 设置环境变量优化 Ollama 性能
  7. os.environ["OLLAMA_NUM_THREADS"] = "8"  # 设置线程数
  8. os.environ["OLLAMA_NUM_CTX"] = "2048"   # 设置上下文长度
  9. # 2. 配置嵌入模型 - 移除无效参数
  10. embeddings = OllamaEmbeddings(
  11.     model="all-minilm"  # 仅保留必要参数
  12. )
  13. # 3. 加载文档
  14. def load_documents(file_path):
  15.     if file_path.endswith(".pdf"):
  16.         loader = PyPDFLoader(file_path)
  17.         print(f"加载 PDF 文档: {file_path}")
  18.     elif file_path.endswith(".txt"):
  19.         loader = TextLoader(file_path)
  20.         print(f"加载文本文档: {file_path}")
  21.     else:
  22.         raise ValueError(f"不支持的文档格式: {file_path}")
  23.     return loader.load()
  24. # 4. 文本分割
  25. def split_documents(docs):
  26.     text_splitter = RecursiveCharacterTextSplitter(
  27.         chunk_size=500,
  28.         chunk_overlap=80,
  29.         separators=["\n\n", "\n", "。", "!", "?", ";"]
  30.     )
  31.     return text_splitter.split_documents(docs)
  32. # 5. 主函数
  33. def main():
  34.     # 示例文档 - 修改为您的文件路径
  35.     document_path = "knowledge.txt"
  36.    
  37.     # 加载和分割文档
  38.     print("开始处理文档...")
  39.     documents = load_documents(document_path)
  40.     chunks = split_documents(documents)
  41.     print(f"文档分割完成: 共 {len(chunks)} 个文本块")
  42.    
  43.     # 创建向量存储
  44.     print("开始生成嵌入向量...")
  45.     vector_store = FAISS.from_documents(
  46.         documents=chunks,
  47.         embedding=embeddings
  48.     )
  49.    
  50.     # 保存知识库索引
  51.     save_path = "my_knowledge_base"
  52.     vector_store.save_local(save_path)
  53.     print(f"知识库构建完成! 保存到: {save_path}")
  54.     print(f"向量库大小: {len(vector_store.index_to_docstore_id)} 个向量")
  55. if __name__ == "__main__":
  56.     main()
复制代码

  • 运行脚本
python3 build_knowledge_base.py
4.png

四、创建 RAG 问答系统


  • 创建一个脚本
touch rag_query.py

  • 写入以下内容
  1. from langchain_community.llms import Ollama
  2. from langchain_community.embeddings import OllamaEmbeddings
  3. from langchain_community.vectorstores import FAISS
  4. from langchain_core.prompts import ChatPromptTemplate
  5. from langchain_core.runnables import RunnablePassthrough
  6. from langchain_core.output_parsers import StrOutputParser
  7. import os
  8. import time
  9. import sys
  10. import select
  11. # 1. 通过环境变量设置优化参数
  12. os.environ["OLLAMA_NUM_THREADS"] = "8"  # 设置线程数
  13. os.environ["OLLAMA_NUM_CTX"] = "2048"   # 设置上下文长度
  14. # 2. 初始化模型
  15. llm = Ollama(
  16.     model="phi3:mini",        # 轻量级语言模型
  17.     temperature=0.3,           # 平衡创造性和准确性
  18.     timeout=120.0              # 设置超时时间
  19.         )
  20. embeddings = OllamaEmbeddings(model="all-minilm")
  21. # 3. 加载知识库
  22. try:
  23.     vector_store = FAISS.load_local(
  24.         "my_knowledge_base",
  25.         embeddings,
  26.         allow_dangerous_deserialization=True
  27.     )
  28.     retriever = vector_store.as_retriever(search_kwargs={"k": 3})
  29.     print("知识库加载成功")
  30. except Exception as e:
  31.     print(f"加载知识库失败: {str(e)}")
  32.     print("请确保已运行 build_knowledge_base.py 构建知识库")
  33.     exit(1)
  34. # 4. 定义提示模板
  35. template = """你是一个专业的知识库助手,请基于以下上下文回答问题。
  36. 如果不知道答案,请说"我不知道",不要编造答案。
  37. 上下文:
  38. {context}
  39. 问题:{question}
  40. 请用中文给出详细回答:"""
  41. prompt = ChatPromptTemplate.from_template(template)
  42. # 5. 构建 RAG 链
  43. rag_chain = (
  44.     {"context": retriever, "question": RunnablePassthrough()}
  45.     | prompt
  46.     | llm
  47.     | StrOutputParser()
  48. )
  49. # 6. 格式化文档显示
  50. def format_docs(docs):
  51.     return "\n\n".join(doc.page_content for doc in docs)
  52. # 7. 改进的输入函数(解决输入卡住问题)
  53. def get_user_input(prompt, timeout=60):
  54.     print(prompt, end='', flush=True)
  55.    
  56.     # 使用 select 检测输入可用性
  57.     if select.select([sys.stdin], [], [], timeout)[0]:
  58.         return sys.stdin.readline().strip()
  59.     return None
  60. # 8. 交互式问答
  61. print("知识库问答系统已启动(输入 'exit' 退出)")
  62. while True:
  63.     try:
  64.         # 使用改进的输入函数
  65.         query = get_user_input("\n你的问题:")
  66.         
  67.         if query is None:
  68.             print("\n输入超时,请重新输入...")
  69.             continue
  70.             
  71.         if query.lower() == "exit":
  72.             break
  73.         
  74.         start_time = time.time()
  75.         
  76.         # 显示检索到的参考内容
  77.         relevant_docs = retriever.invoke(query)
  78.         print("\n[检索到的参考内容]")
  79.         for i, doc in enumerate(relevant_docs[:2]):  # 显示前2个相关片段
  80.             print(f"\n片段 {i+1}:\n{doc.page_content[:200]}...")
  81.         
  82.         # 生成答案
  83.         response = rag_chain.invoke(query)
  84.         
  85.         end_time = time.time()
  86.         
  87.         print(f"\n[答案] (耗时:{end_time - start_time:.2f}秒)")
  88.         print(response)
  89.         
  90.         # 确保输出缓冲区刷新
  91.         sys.stdout.flush()
  92.         
  93.     except KeyboardInterrupt:
  94.         print("\n退出系统...")
  95.         break
  96.     except Exception as e:
  97.         print(f"处理问题时出错: {str(e)}")
  98.         print("请尝试简化问题或稍后重试")
  99.         # 清除可能的输入缓冲区残留
  100.         sys.stdin.readline()
复制代码
五、测试验证

python3 rag_query.py
5.png

根据提示词输入
6.png


来源:程序园用户自行投稿发布,如果侵权,请联系站长删除
免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!

相关推荐

您需要登录后才可以回帖 登录 | 立即注册