soon committed 2 years ago
parent commit 37ceeae6e2

+ 18 - 15
chains/local_doc_qa.py

@@ -117,22 +117,25 @@ class LocalDocQA:
     
     问题:
     {question}"""
-        prompt = PromptTemplate(
-            template=prompt_template,
-            input_variables=["context", "question"]
-        )
-        self.llm.history = chat_history
-        vector_store = FAISS.load_local(vs_path, self.embeddings)
-        knowledge_chain = RetrievalQA.from_llm(
-            llm=self.llm,
-            retriever=vector_store.as_retriever(search_kwargs={"k": self.top_k}),
-            prompt=prompt
-        )
-        knowledge_chain.combine_documents_chain.document_prompt = PromptTemplate(
-            input_variables=["page_content"], template="{page_content}"
-        )
+        if vs_path is None or vs_path == "":  # or (not os.path.exists(vs_path))
+            result = self.llm.chat(query)
+        else:
+            prompt = PromptTemplate(
+                template=prompt_template,
+                input_variables=["context", "question"]
+            )
+            self.llm.history = chat_history
+            vector_store = FAISS.load_local(vs_path, self.embeddings)
+            knowledge_chain = RetrievalQA.from_llm(
+                llm=self.llm,
+                retriever=vector_store.as_retriever(search_kwargs={"k": self.top_k}),
+                prompt=prompt
+            )
+            knowledge_chain.combine_documents_chain.document_prompt = PromptTemplate(
+                input_variables=["page_content"], template="{page_content}"
+            )
 
-        knowledge_chain.return_source_documents = True
+            knowledge_chain.return_source_documents = True
 
         result = knowledge_chain({"query": query})
         self.llm.history[-1][0] = query
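For context, a minimal usage sketch of the branch added above. It assumes the hunk belongs to LocalDocQA.get_knowledge_based_answer and that an init_cfg method has already loaded self.llm and self.embeddings; neither name is visible in this hunk, so both are assumptions. Per the diff, an empty or missing vs_path now sends the query straight to self.llm.chat, while a valid path keeps the original FAISS + RetrievalQA flow.

    # Hypothetical usage sketch; get_knowledge_based_answer and init_cfg are
    # assumed names, not confirmed by this hunk.
    from chains.local_doc_qa import LocalDocQA

    local_doc_qa = LocalDocQA()
    local_doc_qa.init_cfg()  # assumed to load self.llm and self.embeddings

    # With a vector store: the FAISS + RetrievalQA branch runs as before.
    result = local_doc_qa.get_knowledge_based_answer(
        query="What does this project do?",
        vs_path="vector_store/my_kb",
        chat_history=[],
    )

    # Without a vector store: the new branch calls self.llm.chat(query) directly.
    result = local_doc_qa.get_knowledge_based_answer(
        query="What does this project do?",
        vs_path="",
        chat_history=[],
    )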

+ 2 - 1
configs/model_config.py

@@ -19,10 +19,11 @@ llm_model_dict = {
     "chatglm-6b-int4-qe": "THUDM/chatglm-6b-int4-qe",
     "chatglm-6b-int4": "THUDM/chatglm-6b-int4",
     "chatglm-6b": "THUDM/chatglm-6b",
+    "chatyuan": "ClueAI/ChatYuan-large-v2",
 }
 
 # LLM model name
-LLM_MODEL = "chatglm-6b"
+LLM_MODEL = "chatyuan"  # "chatglm-6b"
 
 # Use p-tuning-v2 PrefixEncoder
 USE_PTUNING_V2 = False
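The new entry is consumed the same way as the existing ones: LLM_MODEL is resolved through llm_model_dict to a Hugging Face model id. A short illustrative sketch, not the repository's exact loading code:

    # Illustrative only: resolve the selected model name to its HF identifier.
    from configs.model_config import llm_model_dict, LLM_MODEL

    model_name_or_path = llm_model_dict[LLM_MODEL]
    print(model_name_or_path)  # "ClueAI/ChatYuan-large-v2" with the new default

Switching LLM_MODEL back to "chatglm-6b" restores the previous default without touching any other code.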

+ 13 - 0
models/chatglm_llm.py

@@ -82,6 +82,19 @@ class ChatGLM(LLM):
         self.history = self.history+[[None, response]]
         return response
 
+    def chat(self,
+             prompt: str) -> str:
+        response, _ = self.model.chat(
+            self.tokenizer,
+            prompt,
+            history=[],  # self.history[-self.history_len:] if self.history_len > 0 else
+            max_length=self.max_token,
+            temperature=self.temperature,
+        )
+        torch_gc()
+        self.history = self.history+[[None, response]]
+        return response
+
     def load_model(self,
                    model_name_or_path: str = "THUDM/chatglm-6b",
                    llm_device=LLM_DEVICE,
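A hedged usage sketch of the new method: per the diff, chat() sends the prompt with an empty history (the commented-out expression suggests a windowed history was considered), still appends the exchange to self.history, and returns only the response text. load_model matches the signature shown in the hunk; everything else below is an assumption.

    # Hypothetical usage; assumes load_model() succeeds on the local machine.
    from models.chatglm_llm import ChatGLM

    llm = ChatGLM()
    llm.load_model(model_name_or_path="THUDM/chatglm-6b")

    # Direct chat added by this commit: no prior turns are passed to the model,
    # but the exchange is still appended to llm.history as [None, response].
    answer = llm.chat("Summarize what a vector store is.")
    print(answer)
    print(llm.history[-1])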