
update torch_gc

imClumsyPanda 2 years ago
parent
commit
07ff81a119
7 files changed with 28 additions and 43 deletions
  1. api.py (+3, -9)
  2. chains/local_doc_qa.py (+6, -13)
  3. cli_demo.py (+1, -7)
  4. configs/model_config.py (+9, -1)
  5. models/chatglm_llm.py (+4, -2)
  6. utils/__init__.py (+4, -4)
  7. webui.py (+1, -7)

+ 3 - 9
api.py

@@ -16,16 +16,10 @@ from typing_extensions import Annotated
 
 from chains.local_doc_qa import LocalDocQA
 from configs.model_config import (API_UPLOAD_ROOT_PATH, EMBEDDING_DEVICE,
-                                  EMBEDDING_MODEL, LLM_MODEL)
-
-nltk.data.path = [os.path.join(os.path.dirname(__file__), "nltk_data")] + nltk.data.path
-
-# return top-k text chunk from vector store
-VECTOR_SEARCH_TOP_K = 6
-
-# LLM input history length
-LLM_HISTORY_LEN = 3
+                                  EMBEDDING_MODEL, LLM_MODEL, NLTK_DATA_PATH,
+                                  VECTOR_SEARCH_TOP_K, LLM_HISTORY_LEN)
 
+nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
 
 class BaseResponse(BaseModel):
     code: int = pydantic.Field(200, description="HTTP status code")

+ 6 - 13
chains/local_doc_qa.py

@@ -10,11 +10,6 @@ from langchain.docstore.document import Document
 import numpy as np
 from utils import torch_gc
 
-# return top-k text chunk from vector store
-VECTOR_SEARCH_TOP_K = 6
-
-# LLM input history length
-LLM_HISTORY_LEN = 3
 
 DEVICE_ = EMBEDDING_DEVICE
 DEVICE_ID = "0" if torch.cuda.is_available() else None
@@ -109,7 +104,7 @@ def similarity_search_with_score_by_vector(
         if not isinstance(doc, Document):
             raise ValueError(f"Could not find document for id {_id}, got {doc}")
         docs.append((doc, scores[0][j]))
-    torch_gc(DEVICE)
+    torch_gc()
     return docs
 
 
@@ -181,13 +176,13 @@ class LocalDocQA:
             if vs_path and os.path.isdir(vs_path):
                 vector_store = FAISS.load_local(vs_path, self.embeddings)
                 vector_store.add_documents(docs)
-                torch_gc(DEVICE)
+                torch_gc()
             else:
                 if not vs_path:
                     vs_path = os.path.join(VS_ROOT_PATH,
                                            f"""{os.path.splitext(file)[0]}_FAISS_{datetime.datetime.now().strftime("%Y%m%d_%H%M%S")}""")
                 vector_store = FAISS.from_documents(docs, self.embeddings)
-                torch_gc(DEVICE)
+                torch_gc()
 
             vector_store.save_local(vs_path)
             return vs_path, loaded_files
@@ -206,6 +201,7 @@ class LocalDocQA:
         related_docs_with_score = vector_store.similarity_search_with_score(query,
                                                                             k=self.top_k)
         related_docs = get_docs_with_score(related_docs_with_score)
+        torch_gc()
         prompt = generate_prompt(related_docs, query)
 
         # if streaming:
@@ -220,11 +216,13 @@
         for result, history in self.llm._call(prompt=prompt,
                                               history=chat_history,
                                               streaming=streaming):
+            torch_gc()
             history[-1][0] = query
             response = {"query": query,
                         "result": result,
                         "source_documents": related_docs}
             yield response, history
+            torch_gc()
 
 
 if __name__ == "__main__":
@@ -244,9 +242,4 @@ if __name__ == "__main__":
                    for inum, doc in
                    enumerate(resp["source_documents"])]
     print("\n\n" + "\n\n".join(source_text))
-    # for resp, history in local_doc_qa.get_knowledge_based_answer(query=query,
-    #                                                              vs_path=vs_path,
-    #                                                              chat_history=[],
-    #                                                              streaming=False):
-    #     print(resp["result"])
     pass
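
Taken together, the chains/local_doc_qa.py changes call the new zero-argument torch_gc() after the vector-store operations and around every chunk yielded by the streaming answer generator. Below is a minimal, self-contained sketch of that pattern; fake_llm_stream and the literal strings are hypothetical stand-ins, and only the placement of the torch_gc() calls mirrors the diff above.

from utils import torch_gc


def fake_llm_stream(prompt):
    # Hypothetical stand-in for ChatGLM streaming: yields growing partial replies.
    for partial in ("Hel", "Hello", "Hello!"):
        yield partial, [[prompt, partial]]


def stream_answer(query):
    for result, history in fake_llm_stream(query):
        torch_gc()  # reclaim cached GPU memory before handing out the chunk
        yield {"query": query, "result": result}, history
        torch_gc()  # and again once the consumer resumes the generator


if __name__ == "__main__":
    for resp, history in stream_answer("hello"):
        print(resp["result"])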

+ 1 - 7
cli_demo.py

@@ -3,13 +3,7 @@ from chains.local_doc_qa import LocalDocQA
 import os
 import nltk
 
-nltk.data.path = [os.path.join(os.path.dirname(__file__), "nltk_data")] + nltk.data.path
-
-# return top-k text chunk from vector store
-VECTOR_SEARCH_TOP_K = 6
-
-# LLM input history length
-LLM_HISTORY_LEN = 3
+nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
 
 # Show reply with source text from input document
 REPLY_WITH_SOURCE = True

+ 9 - 1
configs/model_config.py

@@ -49,4 +49,12 @@ PROMPT_TEMPLATE = """已知信息:
 根据上述已知信息,简洁和专业的来回答用户的问题。如果无法从中得到答案,请说 “根据已知信息无法回答该问题” 或 “没有提供足够的相关信息”,不允许在答案中添加编造成分,答案请使用中文。 问题是:{question}"""
 
 # 匹配后单段上下文长度
-CHUNK_SIZE = 500
+CHUNK_SIZE = 250
+
+# LLM input history length
+LLM_HISTORY_LEN = 3
+
+# return top-k text chunk from vector store
+VECTOR_SEARCH_TOP_K = 5
+
+NLTK_DATA_PATH = os.path.join(os.path.dirname(__file__), "nltk_data")
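
With this commit the previously per-module constants are centralized in configs/model_config.py (note that VECTOR_SEARCH_TOP_K also drops from 6 to 5 and CHUNK_SIZE from 500 to 250). A minimal sketch of how a caller now picks them up, mirroring the api.py hunk above; the final print line is illustrative only.

import nltk

from configs.model_config import (NLTK_DATA_PATH, VECTOR_SEARCH_TOP_K,
                                  LLM_HISTORY_LEN)

# Prefer the repo's bundled nltk_data directory over any system-wide installation.
nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path

print(VECTOR_SEARCH_TOP_K, LLM_HISTORY_LEN)  # -> 5 3 after this commit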

+ 4 - 2
models/chatglm_llm.py

@@ -69,12 +69,13 @@ class ChatGLM(LLM):
                     max_length=self.max_token,
                     temperature=self.temperature,
             )):
-                torch_gc(DEVICE)
+                torch_gc()
                 if inum == 0:
                     history += [[prompt, stream_resp]]
                 else:
                     history[-1] = [prompt, stream_resp]
                 yield stream_resp, history
+                torch_gc()
         else:
             response, _ = self.model.chat(
                     self.tokenizer,
@@ -83,9 +84,10 @@
                     max_length=self.max_token,
                     temperature=self.temperature,
             )
-            torch_gc(DEVICE)
+            torch_gc()
             history += [[prompt, response]]
             yield response, history
+            torch_gc()
 
     # def chat(self,
     #          prompt: str) -> str:

+ 4 - 4
utils/__init__.py

@@ -1,10 +1,10 @@
 import torch
 
-def torch_gc(DEVICE):
+def torch_gc():
     if torch.cuda.is_available():
-        with torch.cuda.device(DEVICE):
-            torch.cuda.empty_cache()
-            torch.cuda.ipc_collect()
+        # with torch.cuda.device(DEVICE):
+        torch.cuda.empty_cache()
+        torch.cuda.ipc_collect()
     elif torch.backends.mps.is_available():
         try:
             from torch.mps import empty_cache
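
For reference, the refactored helper now clears the cache on the current device without taking a device argument. The sketch below reassembles utils/__init__.py as it reads after this commit: the CUDA branch is copied from the diff, while the tail of the MPS branch is cut off in this view, so its body and exception handling are assumptions.

import torch


def torch_gc():
    if torch.cuda.is_available():
        # with torch.cuda.device(DEVICE):  # device argument dropped in this commit
        torch.cuda.empty_cache()   # release cached CUDA blocks on the current device
        torch.cuda.ipc_collect()   # clean up CUDA IPC handles left by dead processes
    elif torch.backends.mps.is_available():
        try:
            from torch.mps import empty_cache
            empty_cache()  # assumed: release cached memory on Apple Silicon (MPS)
        except Exception:
            # assumed: older torch builds may not ship torch.mps.empty_cache
            pass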

+ 1 - 7
webui.py

@@ -5,13 +5,7 @@ from chains.local_doc_qa import LocalDocQA
 from configs.model_config import *
 import nltk
 
-nltk.data.path = [os.path.join(os.path.dirname(__file__), "nltk_data")] + nltk.data.path
-
-# return top-k text chunk from vector store
-VECTOR_SEARCH_TOP_K = 6
-
-# LLM input history length
-LLM_HISTORY_LEN = 3
+nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
 
 
 def get_vs_list():