@@ -10,11 +10,6 @@ from langchain.docstore.document import Document
 import numpy as np
 from utils import torch_gc
 
-# return top-k text chunk from vector store
-VECTOR_SEARCH_TOP_K = 6
-
-# LLM input history length
-LLM_HISTORY_LEN = 3
 
 DEVICE_ = EMBEDDING_DEVICE
 DEVICE_ID = "0" if torch.cuda.is_available() else None
@@ -109,7 +104,7 @@ def similarity_search_with_score_by_vector(
         if not isinstance(doc, Document):
             raise ValueError(f"Could not find document for id {_id}, got {doc}")
         docs.append((doc, scores[0][j]))
-    torch_gc(DEVICE)
+    torch_gc()
     return docs
 
 
@@ -181,13 +176,13 @@ class LocalDocQA:
         if vs_path and os.path.isdir(vs_path):
             vector_store = FAISS.load_local(vs_path, self.embeddings)
             vector_store.add_documents(docs)
-            torch_gc(DEVICE)
+            torch_gc()
         else:
             if not vs_path:
                 vs_path = os.path.join(VS_ROOT_PATH,
                                        f"""{os.path.splitext(file)[0]}_FAISS_{datetime.datetime.now().strftime("%Y%m%d_%H%M%S")}""")
             vector_store = FAISS.from_documents(docs, self.embeddings)
-            torch_gc(DEVICE)
+            torch_gc()
 
         vector_store.save_local(vs_path)
         return vs_path, loaded_files
@@ -206,6 +201,7 @@ class LocalDocQA:
         related_docs_with_score = vector_store.similarity_search_with_score(query,
                                                                             k=self.top_k)
         related_docs = get_docs_with_score(related_docs_with_score)
+        torch_gc()
         prompt = generate_prompt(related_docs, query)
 
         # if streaming:
@@ -220,11 +216,13 @@ class LocalDocQA:
         for result, history in self.llm._call(prompt=prompt,
                                               history=chat_history,
                                               streaming=streaming):
+            torch_gc()
             history[-1][0] = query
             response = {"query": query,
                         "result": result,
                         "source_documents": related_docs}
             yield response, history
+        torch_gc()
 
 
 if __name__ == "__main__":
@@ -244,9 +242,4 @@ if __name__ == "__main__":
                    for inum, doc in
                    enumerate(resp["source_documents"])]
     print("\n\n" + "\n\n".join(source_text))
-    # for resp, history in local_doc_qa.get_knowledge_based_answer(query=query,
-    #                                                              vs_path=vs_path,
-    #                                                              chat_history=[],
-    #                                                              streaming=False):
-    #     print(resp["result"])
     pass
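
Note on the torch_gc change: every call site in this diff now invokes torch_gc with no arguments, which assumes the helper in utils was updated to resolve the CUDA device(s) on its own instead of taking a caller-supplied DEVICE. A minimal sketch of what such a zero-argument torch_gc could look like (the body below illustrates that assumption; it is not code from this diff):

    import torch

    def torch_gc():
        # Hypothetical no-argument variant: sweep every visible CUDA
        # device rather than one passed in by the caller.
        if torch.cuda.is_available():
            for device_id in range(torch.cuda.device_count()):
                with torch.cuda.device(device_id):
                    torch.cuda.empty_cache()   # return cached blocks to the driver
                    torch.cuda.ipc_collect()   # reclaim CUDA IPC shared memory

Calling it after vector-store writes, after retrieval, and after each generation step, as the diff does, keeps cached GPU memory from accumulating across requests.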