@@ -2,7 +2,6 @@ from langchain.chains import RetrievalQA
 from langchain.prompts import PromptTemplate
 from langchain.embeddings.huggingface import HuggingFaceEmbeddings
 from langchain.vectorstores import FAISS
-from langchain.vectorstores.base import VectorStoreRetriever
 from langchain.document_loaders import UnstructuredFileLoader
 from models.chatglm_llm import ChatGLM
 import sentence_transformers
@@ -34,22 +33,20 @@ def load_file(filepath):
     docs = loader.load_and_split(text_splitter=textsplitter)
     return docs

-
-def get_relevant_documents(self, query: str) -> List[Document]:
-    if self.search_type == "similarity":
-        docs = self.vectorstore._similarity_search_with_relevance_scores(query, **self.search_kwargs)
-        for doc in docs:
-            doc[0].metadata["score"] = doc[1]
-        docs = [doc[0] for doc in docs]
-    elif self.search_type == "mmr":
-        docs = self.vectorstore.max_marginal_relevance_search(
-            query, **self.search_kwargs
-        )
-    else:
-        raise ValueError(f"search_type of {self.search_type} not allowed.")
-    return docs
+def generate_prompt(related_docs: List[Document],
+                    query: str,
+                    prompt_template=PROMPT_TEMPLATE) -> str:
+    context = "\n".join([doc.page_content for doc in related_docs])
+    prompt = prompt_template.replace("{question}", query).replace("{context}", context)
+    return prompt


+def get_docs_with_score(docs_with_score):
+    docs = []
+    for doc, score in docs_with_score:
+        doc.metadata["score"] = score
+        docs.append(doc)
+    return docs

 class LocalDocQA:
     llm: object = None
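
The two helpers added above replace the monkey-patched retriever. As a minimal sketch of what generate_prompt assembles (hypothetical: the inline TEMPLATE stands in for the project's PROMPT_TEMPLATE constant, which is defined elsewhere in the repo; Document is langchain's):

    from langchain.docstore.document import Document

    TEMPLATE = """Known content:
    {context}

    Question:
    {question}"""

    docs = [Document(page_content="ChatGLM-6B is an open bilingual dialogue model."),
            Document(page_content="FAISS is a library for efficient similarity search.")]
    # Both placeholders are filled by plain str.replace, so the template
    # must contain literal {context} and {question} markers.
    print(generate_prompt(docs, "What is ChatGLM-6B?", prompt_template=TEMPLATE))
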
@@ -73,8 +70,6 @@ class LocalDocQA:

         self.embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[embedding_model],
                                                 model_kwargs={'device': embedding_device})
-        # self.embeddings.client = sentence_transformers.SentenceTransformer(self.embeddings.model_name,
-        #                                                                    device=embedding_device)
         self.top_k = top_k

     def init_knowledge_vector_store(self,
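
Dropping the commented-out SentenceTransformer assignment is safe because HuggingFaceEmbeddings forwards model_kwargs to the underlying sentence-transformers client, so the device is set in one place. A minimal sketch, assuming a model name along the lines of this project's embedding_model_dict:

    from langchain.embeddings.huggingface import HuggingFaceEmbeddings

    embeddings = HuggingFaceEmbeddings(
        model_name="GanymedeNil/text2vec-large-chinese",  # assumed dict entry
        model_kwargs={"device": "cuda"},                  # or "cpu"
    )
    print(len(embeddings.embed_query("test query")))      # embedding dimensionality
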
@@ -134,34 +129,30 @@ class LocalDocQA:
     def get_knowledge_based_answer(self,
                                    query,
                                    vs_path,
-                                   chat_history=[], ):
-        prompt_template = """Answer the user's question concisely and professionally, based on the known information below.
-If no answer can be derived from it, say "The question cannot be answered from the known information" or "Not enough relevant information was provided"; do not add fabricated elements to the answer, and please answer in Chinese.
-
-Known content:
-{context}
-
-Question:
-{question}"""
-        prompt = PromptTemplate(
-            template=prompt_template,
-            input_variables=["context", "question"]
-        )
-        self.llm.history = chat_history
+                                   chat_history=[],
+                                   streaming=True):
+        self.llm.streaming = streaming
         vector_store = FAISS.load_local(vs_path, self.embeddings)
-        vs_r = vector_store.as_retriever(search_type="mmr",
-                                         search_kwargs={"k": self.top_k})
-        # VectorStoreRetriever.get_relevant_documents = get_relevant_documents
-        knowledge_chain = RetrievalQA.from_llm(
-            llm=self.llm,
-            retriever=vs_r,
-            prompt=prompt
-        )
-        knowledge_chain.combine_documents_chain.document_prompt = PromptTemplate(
-            input_variables=["page_content"], template="{page_content}"
-        )
+        related_docs_with_score = vector_store.similarity_search_with_score(query,
+                                                                            k=self.top_k)
+        related_docs = get_docs_with_score(related_docs_with_score)
+        prompt = generate_prompt(related_docs, query)

-        knowledge_chain.return_source_documents = True
-        result = knowledge_chain({"query": query})
-        self.llm.history[-1][0] = query
-        return result, self.llm.history
+        if streaming:
+            for result, history in self.llm._call(prompt=prompt,
+                                                  history=chat_history):
+                history[-1] = list(history[-1])
+                history[-1][0] = query
+                response = {"query": query,
+                            "result": result,
+                            "source_documents": related_docs}
+                yield response, history
+        else:
+            result, history = self.llm._call(prompt=prompt,
+                                             history=chat_history)
+            history[-1] = list(history[-1])
+            history[-1][0] = query
+            response = {"query": query,
+                        "result": result,
+                        "source_documents": related_docs}
+            yield response, history
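
Because get_knowledge_based_answer now contains yield, it is a generator in both modes, and callers iterate it either way. A hypothetical consumption sketch (init_cfg and its arguments are assumed from the rest of the file, which this diff does not show):

    qa = LocalDocQA()
    qa.init_cfg(llm_model="chatglm-6b", embedding_model="text2vec", top_k=5)  # assumed signature

    for response, history in qa.get_knowledge_based_answer(query="What is this project?",
                                                           vs_path="/path/to/vector_store",
                                                           chat_history=[],
                                                           streaming=True):
        # Each iteration carries the answer so far; "source_documents"
        # holds the retrieved docs with scores stored in metadata.
        print(response["result"])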