File size: 3,650 Bytes
6ca8627
cacdd22
b324113
 
95e7c6b
e501604
1570310
 
 
6520b77
a871520
6520b77
 
 
0252582
 
d6ec152
0252582
 
6520b77
 
 
 
 
 
 
 
 
cacdd22
58504c7
951a4ca
58504c7
6520b77
3e04659
 
 
 
 
 
9a8c838
6520b77
b324113
41a17c2
 
7ca0615
41a17c2
994ccce
41a17c2
a20e394
5c43704
41a17c2
9860b3d
3e04659
9860b3d
3e04659
 
9860b3d
a20e394
 
9860b3d
a20e394
 
 
951a4ca
cacdd22
8e2baa5
994ccce
a20e394
951a4ca
41a17c2
994ccce
a20e394
951a4ca
41a17c2
994ccce
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import streamlit as st
#from langchain.chat_models import ChatOpenAI
from langchain.document_loaders import WebBaseLoader
from langchain.chains.summarize import load_summarize_chain
from langchain.text_splitter import RecursiveCharacterTextSplitter
from bs4 import BeautifulSoup   #WebBaseLoader会需要 用到?
from langchain import HuggingFaceHub
import requests
import sys
from huggingface_hub import InferenceClient

import os
from dotenv import load_dotenv
# Pull settings from the process environment, after letting python-dotenv
# populate it from a local .env file if one exists.
load_dotenv()
hf_token = os.getenv('HUGGINGFACEHUB_API_TOKEN')
repo_id = os.getenv('repo_id')

# Hosted-inference LLM. The generation settings below were tuned for a
# StarChat model (per the original author's notes); generation stops at
# token id 49155.
llm = HuggingFaceHub(
    repo_id=repo_id,
    huggingfacehub_api_token=hf_token,
    model_kwargs=dict(
        min_length=512,
        max_new_tokens=1024,
        do_sample=True,
        temperature=0.01,
        top_k=50,
        top_p=0.95,
        eos_token_id=49155,
    ),
)

# Use the "refine" chain type rather than "stuff": in the author's testing the
# "stuff" mode tended to fail, presumably because the whole page exceeded the
# model's token limit.
chain = load_summarize_chain(llm, chain_type="refine")

def remove_context(text):
    """Strip a leading "Context:" section echoed back by the model.

    NOTE(review): the original script called ``remove_context`` without
    defining it, so every run raised NameError and was reported to the user
    as a bad URL. This reconstruction is deliberately conservative: it only
    drops text when the output *starts* with "Context:", removing everything
    up to and including the first blank line; otherwise the text is returned
    unchanged. Confirm against the model's actual output format.
    """
    if not text.lstrip().startswith("Context:"):
        return text
    _, blank_line, rest = text.partition("\n\n")
    return rest if blank_line else text


# Chat-template markers that must never reach the user.
_CHAT_MARKERS = ("<|end|>", "<|user|>", "<|system|>", "<|assistant|>")


def _clean_ai_response(raw):
    """Reduce raw model output to the summary text shown to the user.

    Keeps only the text before the first "<|end|>" marker, collapses blank
    lines, removes leftover chat-role markers, and cuts off any trailing
    "Unhelpful Answer:" or "Note:" section the model appended.
    """
    text = raw.split("<|end|>")[0].strip().replace("\n\n", "\n")
    for marker in _CHAT_MARKERS:
        text = text.replace(marker, "")
    text = text.split("Unhelpful Answer:")[0].strip()
    return text.split("Note:")[0].strip()


url = st.text_input("Enter website URL to summarize (format: https://www.usinoip.com):").strip()

# An empty or whitespace-only entry means no URL has been submitted yet.
if url:
    with st.spinner("AI Thinking...Please wait a while to Cheers!"):
        print("Website to Chat: " + url)
        try:
            docs = WebBaseLoader(url).load()
        except Exception as exc:  # malformed URL, network/HTTP failure, unparsable page
            print(f"Loading {url!r} failed: {exc!r}")
            st.write("Wrong URL or URL not parsable.")
        else:
            print("Webpage contents loaded")
            try:
                # chain.run's result is not guaranteed to be a plain str (the
                # original author hit errors concatenating it), so coerce it.
                result = str(chain.run(docs))
            except Exception as exc:  # LLM / inference API failure, not a URL problem
                print(f"Summarization of {url!r} failed: {exc!r}")
                st.write("Summarization failed. Please try again later.")
            else:
                print("Chain run finished")
                cleaned_initial_ai_response = remove_context(result)
                print("AI response result cleaned initially: " + cleaned_initial_ai_response)
                final_result = _clean_ai_response(cleaned_initial_ai_response)

                print("AI Summarization:")
                print(final_result)

                st.write("AI Summarization:")
                st.write(final_result)