Browse Source

完善API接口,完善模型加载 (#247)

* 完善知识库路径问题,完善api接口

统一webui、API接口知识库路径,后续路径如下:
知识库路经就是:/项目代码文件夹/vector_store/'知识库名字'
文件存放路经:/项目代码文件夹/content/'知识库名字'

修复通过api接口创建知识库的BUG,完善API接口功能。

* Update model_config.py


* 完善知识库路径问题,完善api接口 (#245) (#246)

* Fix 知识库无法上载,NLTK_DATA_PATH路径错误 (#236)

* Update chatglm_llm.py (#242)

* 完善知识库路径问题,完善api接口

统一webui、API接口知识库路径,后续路径如下:
知识库路经就是:/项目代码文件夹/vector_store/'知识库名字'
文件存放路经:/项目代码文件夹/content/'知识库名字'

修复通过api接口创建知识库的BUG,完善API接口功能。

* Update model_config.py

---------

Co-authored-by: shrimp <411161555@qq.com>
Co-authored-by: Bob Chang <bob-chang@outlook.com>

* 优化API接口,完善模型top_p参数

优化API接口,知识库非必须选项。
完善模型top_p参数

* 完善API接口,完善模型加载

API接口知识库非必须加载项
完善模型top_p参数。

---------

Co-authored-by: imClumsyPanda <littlepanda0716@gmail.com>
Co-authored-by: Bob Chang <bob-chang@outlook.com>
shrimp 2 years ago
parent
commit
0d9db37f45
2 changed files with 25 additions and 19 deletions
  1. 21 17
      api.py
  2. 4 2
      models/chatglm_llm.py

+ 21 - 17
api.py

@@ -170,32 +170,36 @@ async def delete_docs(
 
 
 async def chat(
-        knowledge_base_id: str = Body(..., description="Knowledge Base Name", example="kb1"),
-        question: str = Body(..., description="Question", example="工伤保险是什么?"),
+        knowledge_base_id: str = Body(..., description="知识库名字", example="kb1"),
+        question: str = Body(..., description="问题", example="工伤保险是什么?"),
         history: List[List[str]] = Body(
             [],
-            description="History of previous questions and answers",
+            description="问题及答案的历史记录",
             example=[
                 [
-                    "工伤保险是什么?",
-                    "工伤保险是指用人单位按照国家规定,为本单位的职工和用人单位的其他人员,缴纳工伤保险费,由保险机构按照国家规定的标准,给予工伤保险待遇的社会保险制度。",
+                    "这里是问题,如:工伤保险是什么?",
+                    "答案:工伤保险是指用人单位按照国家规定,为本单位的职工和用人单位的其他人员,缴纳工伤保险费,由保险机构按照国家规定的标准,给予工伤保险待遇的社会保险制度。",
                 ]
             ],
         ),
 ):
     vs_path = os.path.join(VS_ROOT_PATH, knowledge_base_id)
-    if not os.path.exists(vs_path):
-        raise ValueError(f"Knowledge base {knowledge_base_id} not found")
-
-    for resp, history in local_doc_qa.get_knowledge_based_answer(
-            query=question, vs_path=vs_path, chat_history=history, streaming=True
-    ):
-        pass
-    source_documents = [
-        f"""出处 [{inum + 1}] {os.path.split(doc.metadata['source'])[-1]}:\n\n{doc.page_content}\n\n"""
-        f"""相关度:{doc.metadata['score']}\n\n"""
-        for inum, doc in enumerate(resp["source_documents"])
-    ]
+    resp = {}
+    if os.path.exists(vs_path) and knowledge_base_id:
+        for resp, history in local_doc_qa.get_knowledge_based_answer(
+                query=question, vs_path=vs_path, chat_history=history, streaming=False
+        ):
+            pass
+        source_documents = [
+            f"""出处 [{inum + 1}] {os.path.split(doc.metadata['source'])[-1]}:\n\n{doc.page_content}\n\n"""
+            f"""相关度:{doc.metadata['score']}\n\n"""
+            for inum, doc in enumerate(resp["source_documents"])
+        ]
+    else:
+        for resp_s, history in local_doc_qa.llm._call(prompt=question, history=history, streaming=False):
+            pass
+        resp["result"] = resp_s
+        source_documents =[("当前知识库为空,如需基于知识库进行问答,请先加载知识库后,再进行提问。")]
 
     return ChatMessage(
         question=question,

+ 4 - 2
models/chatglm_llm.py

@@ -43,7 +43,7 @@ def auto_configure_device_map(num_gpus: int) -> Dict[str, int]:
 
 class ChatGLM(LLM):
     max_token: int = 10000
-    temperature: float = 0.01
+    temperature: float = 0.8
     top_p = 0.9
     # history = []
     tokenizer: object = None
@@ -68,6 +68,7 @@ class ChatGLM(LLM):
                     history=history[-self.history_len:-1] if self.history_len > 0 else [],
                     max_length=self.max_token,
                     temperature=self.temperature,
+                    top_p=self.top_p,
             )):
                 torch_gc()
                 if inum == 0:
@@ -83,6 +84,7 @@ class ChatGLM(LLM):
                 history=history[-self.history_len:] if self.history_len > 0 else [],
                 max_length=self.max_token,
                 temperature=self.temperature,
+                top_p=self.top_p,
             )
             torch_gc()
             history += [[prompt, response]]
@@ -141,7 +143,7 @@ class ChatGLM(LLM):
                 from accelerate import dispatch_model
 
                 model = AutoModel.from_pretrained(model_name_or_path, trust_remote_code=True,
-                        config=model_config, **kwargs)
+                                                  config=model_config, **kwargs)
                 if LLM_LORA_PATH and use_lora:
                     from peft import PeftModel
                     model = PeftModel.from_pretrained(model, LLM_LORA_PATH)