add stream support to webui.py

imClumsyPanda · 2 years ago
Parent commit: 966def8cfe
5 changed files with 25 additions and 268 deletions
  1. api.py  +0 -103
  2. chains/local_doc_qa.py  +0 -2
  3. chains/test.ipynb  +0 -142
  4. models/chatglm_llm.py  +7 -4
  5. webui.py  +18 -17

+ 0 - 103
api.py

@@ -1,103 +0,0 @@
-from configs.model_config import *
-from chains.local_doc_qa import LocalDocQA
-import os
-import nltk
-
-import uvicorn
-from fastapi import FastAPI, File, UploadFile
-from pydantic import BaseModel
-from starlette.responses import RedirectResponse
-
-app = FastAPI()
-
-global local_doc_qa, vs_path
-
-nltk.data.path = [os.path.join(os.path.dirname(__file__), "nltk_data")] + nltk.data.path
-
-# return top-k text chunk from vector store
-VECTOR_SEARCH_TOP_K = 10
-
-# LLM input history length
-LLM_HISTORY_LEN = 3
-
-# Show reply with source text from input document
-REPLY_WITH_SOURCE = False
-
-class Query(BaseModel):
-    query: str
-
-@app.get('/')
-async def document():
-    return RedirectResponse(url="/docs")
-
-@app.on_event("startup")
-async def get_local_doc_qa():
-    global local_doc_qa
-    local_doc_qa = LocalDocQA()
-    local_doc_qa.init_cfg(llm_model=LLM_MODEL,
-                          embedding_model=EMBEDDING_MODEL,
-                          embedding_device=EMBEDDING_DEVICE,
-                          llm_history_len=LLM_HISTORY_LEN,
-                          top_k=VECTOR_SEARCH_TOP_K)
-    
-
-@app.post("/file")
-async def upload_file(UserFile: UploadFile=File(...),):
-    global vs_path
-    response = {
-        "msg": None,
-        "status": 0
-    }
-    try:
-        filepath = './content/' + UserFile.filename
-        content = await UserFile.read()
-        # print(UserFile.filename)
-        with open(filepath, 'wb') as f:
-            f.write(content)
-        vs_path, files = local_doc_qa.init_knowledge_vector_store(filepath)
-        response = {
-            'msg': 'seccess' if len(files)>0 else 'fail',
-            'status': 1 if len(files)>0 else 0,
-            'loaded_files': files
-        }
-        
-    except Exception as err:
-        response["message"] = err
-        
-    return response 
-
-@app.post("/qa")
-async def get_answer(query: str = ""):
-    response = {
-        "status": 0,
-        "message": "",
-        "answer": None
-    }
-    global vs_path
-    history = []
-    try:
-        resp, history = local_doc_qa.get_knowledge_based_answer(query=query,
-                                                                vs_path=vs_path,
-                                                                chat_history=history)
-        if REPLY_WITH_SOURCE:
-            response["answer"] = resp
-        else:
-            response['answer'] = resp["result"]
-        
-        response["message"] = 'successful'
-        response["status"] = 1
-
-    except Exception as err:
-        response["message"] = err
-        
-    return response
-
-
-if __name__ == "__main__":
-    uvicorn.run(
-        app=app,
-        host='0.0.0.0', 
-        port=8100,
-        reload=True,
-        )
-

+ 0 - 2
chains/local_doc_qa.py

@@ -141,7 +141,6 @@ class LocalDocQA:
         if streaming:
             for result, history in self.llm._call(prompt=prompt,
                                                   history=chat_history):
-                history[-1] = list(history[-1])
                 history[-1][0] = query
                 response = {"query": query,
                             "result": result,
@@ -150,7 +149,6 @@
         else:
             result, history = self.llm._call(prompt=prompt,
                                              history=chat_history)
-            history[-1] = list(history[-1])
             history[-1][0] = query
             response = {"query": query,
                         "result": result,

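Note: the dropped `history[-1] = list(history[-1])` conversion appears to have been a guard against tuple-typed history entries; after this commit the ChatGLM wrapper (see models/chatglm_llm.py below) builds each history turn as a list, so the in-place `history[-1][0] = query` assignment works directly. A minimal stand-alone sketch of the difference (illustration only, not code from the repository):

# Before: a tuple-typed turn cannot be edited in place.
old_turn = ("augmented prompt with retrieved context", "generated answer")
# old_turn[0] = "original user query"   # TypeError: 'tuple' object does not support item assignment
fixed_turn = list(old_turn)             # hence the removed list() conversion
fixed_turn[0] = "original user query"

# After: the wrapper yields list-typed turns, so the caller can rewrite
# the displayed question directly.
history = [["augmented prompt with retrieved context", "generated answer"]]
history[-1][0] = "original user query"
print(history)   # [['original user query', 'generated answer']]
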
+ 0 - 142
chains/test.ipynb

The diff for this file is not shown because the file is too large.

+ 7 - 4
models/chatglm_llm.py

@@ -74,14 +74,17 @@ class ChatGLM(LLM):
               history: List[List[str]] = [],
               stop: Optional[List[str]] = None) -> str:
         if self.streaming:
-            history = history + [[None, ""]]
-            for stream_resp, history in self.model.stream_chat(
+            for inum, (stream_resp, _) in enumerate(self.model.stream_chat(
                     self.tokenizer,
                     prompt,
-                    history=history[-self.history_len:] if self.history_len > 0 else [],
+                    history=history[-self.history_len:-1] if self.history_len > 0 else [],
                     max_length=self.max_token,
                     temperature=self.temperature,
-            ):
+            )):
+                if inum == 0:
+                    history += [[prompt, stream_resp]]
+                else:
+                    history[-1] = [prompt, stream_resp]
                 yield stream_resp, history

         else:

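Note: the streaming branch no longer seeds `history` with a `[None, ""]` placeholder. It now enumerates the chunks from `stream_chat`, opens a new `[prompt, chunk]` turn on the first chunk, overwrites that last turn on every later chunk, and passes `history[-self.history_len:-1]` back to the model. A self-contained sketch of the same update pattern; `fake_stream_chat` and its canned chunks are stand-ins for `model.stream_chat`, not repository code:

def fake_stream_chat(prompt):
    # Yields (accumulated_response, _) pairs, the shape ChatGLM's stream_chat produces.
    for chunk in ["Hello", "Hello,", "Hello, world"]:
        yield chunk, None

def stream_call(prompt, history):
    # Mirrors the updated branch: the new turn is created inside the loop.
    for inum, (stream_resp, _) in enumerate(fake_stream_chat(prompt)):
        if inum == 0:
            history += [[prompt, stream_resp]]    # first chunk opens a new turn
        else:
            history[-1] = [prompt, stream_resp]   # later chunks replace that turn
        yield stream_resp, history

history = []
for resp, history in stream_call("hi", history):
    print(resp, history[-1])
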
+ 18 - 17
webui.py

@@ -33,23 +33,23 @@ def get_answer(query, vs_path, history, mode):
     if mode == "知识库问答":
         if vs_path:
             for resp, history in local_doc_qa.get_knowledge_based_answer(
-                query=query, vs_path=vs_path, chat_history=history):
-    #         source = "".join([f"""<details> <summary>出处 {i + 1}</summary>
-    # {doc.page_content}
-    #
-    # <b>所属文件:</b>{doc.metadata["source"]}
-    # </details>""" for i, doc in enumerate(resp["source_documents"])])
-    #         history[-1][-1] += source
+                    query=query, vs_path=vs_path, chat_history=history):
+                source = "\n\n"
+                source += "".join(
+                    [f"""<details> <summary>出处 [{i + 1}] {os.path.split(doc.metadata["source"])[-1]}</summary>\n"""
+                     f"""{doc.page_content}\n"""
+                     f"""</details>"""
+                     for i, doc in
+                     enumerate(resp["source_documents"])])
+                history[-1][-1] += source
                 yield history, ""
         else:
-            history = history + [[query, ""]]
-            for resp in local_doc_qa.llm._call(query):
+            for resp, history in local_doc_qa.llm._call(query, history):
                 history[-1][-1] = resp + (
                     "\n\n当前知识库为空,如需基于知识库进行问答,请先加载知识库后,再进行提问。" if mode == "知识库问答" else "")
                 yield history, ""
     else:
-        history = history + [[query, ""]]
-        for resp in local_doc_qa.llm._call(query):
+        for resp, history in local_doc_qa.llm._call(query, history):
             history[-1][-1] = resp
             yield history, ""

@@ -269,9 +269,10 @@ with gr.Blocks(css=block_css) as demo:
                             outputs=chatbot
                             )

-demo.queue(concurrency_count=3
-           ).launch(server_name='0.0.0.0',
-                    server_port=7860,
-                    show_api=False,
-                    share=False,
-                    inbrowser=False)
+(demo
+ .queue(concurrency_count=3)
+ .launch(server_name='0.0.0.0',
+         server_port=7860,
+         show_api=False,
+         share=False,
+         inbrowser=False))
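
Note: with streaming wired through, `get_answer` no longer appends `[query, ""]` itself; it consumes `(resp, history)` pairs from `llm._call(query, history)` (or from `get_knowledge_based_answer`) and, in the knowledge-base branch, appends one collapsible `<details>` block per retrieved document to the displayed answer. A stand-alone sketch of that source-citation formatting; the `Doc` dataclass and the sample paths are only stand-ins for the `source_documents` objects:

import os
from dataclasses import dataclass

@dataclass
class Doc:                 # stand-in for a retrieved document
    page_content: str
    metadata: dict

docs = [Doc("first matched chunk", {"source": "/knowledge/manual.md"}),
        Doc("second matched chunk", {"source": "/knowledge/faq.md"})]

# Same formatting as the new branch: one collapsible block per source,
# titled 出处 [n] plus the bare file name, holding the matched chunk.
source = "\n\n" + "".join(
    f"""<details> <summary>出处 [{i + 1}] {os.path.split(doc.metadata["source"])[-1]}</summary>\n"""
    f"""{doc.page_content}\n"""
    f"""</details>"""
    for i, doc in enumerate(docs))

history = [["user query", "streamed answer"]]
history[-1][-1] += source   # appended to the streamed answer before each yield
print(history[-1][-1])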

Some files were not shown because too many files changed in this diff.