Spaces:
Sleeping
Sleeping
# %% | |
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline | |
from langchain.vectorstores import Chroma | |
from langchain.embeddings import HuggingFaceEmbeddings | |
import gradio as gr | |
import hanzidentifier | |
import re | |
import chinese_converter | |
import pathlib | |
# Absolute path of the directory containing this script, kept as a plain string.
current_path = str(pathlib.Path(__file__).parent.resolve())
# %% | |
# Load the chat LLM and its tokenizer directly by model id.
# The commented-out lines are alternative loaders kept for reference.
llm_model_name = "Qwen/Qwen1.5-0.5B-Chat"
# pipe = pipeline("text2text-generation", model=model)
tokenizer = AutoTokenizer.from_pretrained(llm_model_name)
model = AutoModelForCausalLM.from_pretrained(llm_model_name)
# model = AutoPeftModelForCausalLM.from_pretrained(
#     "Qwen1.5_0.5B_Chat_sft_full/checkpoint-300",
#     low_cpu_mem_usage=True,
# )
# %% | |
# %% | |
# Text-embedding model used to vectorise riddles for retrieval; configured for CPU
# and with embedding normalisation switched off.
vec_model_name = "shibing624/text2vec-base-chinese"
model_kwargs = {"device": "cpu"}
encode_kwargs = {"normalize_embeddings": False}
huggingface_embeddings = HuggingFaceEmbeddings(
    model_name=vec_model_name,
    model_kwargs=model_kwargs,
    encode_kwargs=encode_kwargs,
)
# %% | |
# Open the Chroma store persisted under ./chroma/ and report how many entries it holds.
persist_directory = 'chroma/'
vectordb = Chroma(
    persist_directory=persist_directory,
    embedding_function=huggingface_embeddings,
)
print(vectordb._collection.count())
# %% | |
# UI strings, one entry per language in the order used by change_language:
# 0 = Simplified Chinese, 1 = Traditional Chinese, 2 = English.
text_input_label = [
    "谜面",
    "謎面",
    "Riddle",
]
text_output_label = [
    "谜底",
    "謎底",
    "Answer",
]
clear_label = [
    "清除",
    "清除",
    "Clear",
]
submit_label = [
    "提交",
    "提交",
    "Submit",
]
# A retrieved riddle is used as a hint only when its relevance score exceeds this.
threshold = 0.6
# %% | |
# helper functions for prompt processing for this LLM | |
# def preprocess(text): | |
# text = text.replace("\n", "\\n").replace("\t", "\\t") | |
# return text | |
# def postprocess(text): | |
# return text.replace("\\n", "\n").replace("\\t", "\t").replace('%20',' ') | |
# Get the riddle's answer from the LLM, optionally primed with retrieved hints.
def answer(input_text, context=None):
    """Return the answer to a riddle prompt.

    Args:
        input_text: Prompt text describing the riddle.
        context: Optional sequence of ``(riddle, answer, score)`` tuples from
            retrieval. The first entry is treated as the best match.

    Returns:
        ``"谜底是:<answer>"`` taken straight from the first context entry when
        its score is at least 0.9; otherwise the text generated by the LLM.
    """
    if context:
        best_riddle, best_answer, best_score = context[0][0], context[0][1], context[0][2]
        print(f"====\n{input_text}\n{best_riddle} 谜底是:{best_answer} {best_score}")
        # A near-exact match in the riddle store is trusted over the small LLM.
        if best_score >= 0.9:
            return f"谜底是:{best_answer}"
        # NOTE(review): one hint per line is assumed here; confirm against the
        # intended prompt layout.
        tips = "提示:\n" + "".join(
            f"{number}. 谜面:{tip[0]} 谜底是:{tip[1]} \n"
            for number, tip in enumerate(context, start=1)
        )
    else:
        tips = ""

    prompt = f"{input_text}\n\n{tips}\n\n谜底是什么?".strip()
    print(f"===\n{prompt}")

    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": prompt},
    ]
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    model_inputs = tokenizer([text], return_tensors="pt").to(device="cpu")
    # Pass the attention mask explicitly so generate() does not have to guess it.
    # Decoding is greedy (do_sample=False), so no sampling parameters are passed:
    # they would be ignored and only trigger a warning.
    generated_ids = model.generate(
        model_inputs.input_ids,
        attention_mask=model_inputs.attention_mask,
        max_new_tokens=128,
        do_sample=False,
    )
    # generate() returns prompt + completion; keep only the newly generated tokens.
    completions = [
        output_ids[len(input_ids):]
        for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
    ]
    return tokenizer.batch_decode(completions, skip_special_tokens=True)[0]
# Retrieval helper: look up similar riddles in the vector store.
def helper_rag(text):
    """Return stored riddles similar to ``text`` as hint tuples.

    The three closest entries are requested from ``vectordb``; only those whose
    relevance score is strictly greater than the module-level ``threshold`` are
    kept.

    Returns:
        A list of ``(riddle_text, answer, score)`` tuples, in the order the
        vector store returned them.
    """
    scored_docs = vectordb.similarity_search_with_relevance_scores(text, k=3)
    return [
        (document.page_content, document.metadata['answer'], score)
        for document, score in scored_docs
        if score > threshold
    ]
# Gradio callback: normalise the riddle text, retrieve hints and ask the LLM.
def helper_text(text_input, radio=None):
    """Solve the riddle in ``text_input`` and return the answer text.

    Args:
        text_input: The riddle as typed by the user. Traditional Chinese input
            is converted to Simplified for retrieval and generation.
        radio: Selected UI language. Accepted for the Gradio callback
            signature; currently unused.

    Returns:
        The answer text, converted back to Traditional Chinese when the input
        was Traditional. An empty string when the input is empty or only
        whitespace.
    """
    # Nothing to solve: skip retrieval and generation instead of asking the
    # model to answer an empty riddle.
    if not text_input or not text_input.strip():
        return ""

    chinese_type = "simplified"
    if hanzidentifier.is_traditional(text_input):
        text_input = chinese_converter.to_simplified(text_input)
        chinese_type = "traditional"

    # Replace an English "Hint:" marker (any case) with the Chinese cue "猜".
    text_input = re.sub(r'hint:', "猜", text_input, flags=re.I)

    text = f"猜谜语:\n谜面:{text_input}"
    context = helper_rag(text_input)
    output = answer(text, context=context)
    print(output)

    if chinese_type == "traditional":
        output = chinese_converter.to_traditional(output)
    return output
# Translate text into English with the chat LLM.
def translate(input_text):
    """Use the LLM to translate ``input_text`` into English.

    Args:
        input_text: Text to translate.

    Returns:
        The text generated by the model for the translation prompt, or an
        empty string when ``input_text`` is empty or only whitespace.
    """
    # An empty box has nothing to translate; avoid prompting the model with it.
    if not input_text or not input_text.strip():
        return ""

    prompt = f"翻译以下內容成英语:\n{input_text}\n"
    print(prompt)

    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": prompt},
    ]
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    model_inputs = tokenizer([text], return_tensors="pt").to(device="cpu")
    # Pass the attention mask explicitly so generate() does not have to guess it.
    # Decoding is greedy (do_sample=False), so no sampling parameters are passed:
    # they would be ignored and only trigger a warning.
    generated_ids = model.generate(
        model_inputs.input_ids,
        attention_mask=model_inputs.attention_mask,
        max_new_tokens=128,
        do_sample=False,
    )
    # generate() returns prompt + completion; keep only the newly generated tokens.
    completions = [
        output_ids[len(input_ids):]
        for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
    ]
    return tokenizer.batch_decode(completions, skip_special_tokens=True)[0]
# Gradio callback that switches the UI language.
def change_language(radio, text_input, text_output, markdown,
                    markdown_msg1, markdown_msg2, translate_btn):
    """Rebuild the language-dependent UI pieces for the selected language.

    Args:
        radio: Selected language ("简体中文", "繁體中文" or "English").
        text_input, text_output: Current contents of the two text boxes.
        markdown, markdown_msg1, markdown_msg2: Current Markdown texts.
        translate_btn: Current value of the translate button (not read).

    Returns:
        A list with, in order: the input and output text boxes, the clear and
        submit button captions, the three Markdown texts, and the translate
        button. Chinese selections convert the box contents and Markdown to the
        chosen script; English leaves the texts as they are and shows the
        translate button; any other value uses the Simplified labels and
        converts only the Markdown.
    """
    index = {"简体中文": 0, "繁體中文": 1, "English": 2}.get(radio, 0)

    # Script converter for this selection; English performs no conversion.
    if radio == "English":
        convert = None
    elif radio == "繁體中文":
        convert = chinese_converter.to_traditional
    else:
        convert = chinese_converter.to_simplified

    if radio in ("简体中文", "繁體中文"):
        # Explicit Chinese selection: convert the box contents as well.
        text_input = gr.Textbox(value=convert(text_input), label=text_input_label[index])
        text_output = gr.Textbox(value=convert(text_output), label=text_output_label[index])
    else:
        # Only the labels change.
        text_input = gr.Textbox(label=text_input_label[index])
        text_output = gr.Textbox(label=text_output_label[index])

    if convert is not None:
        markdown = convert(markdown)
        markdown_msg1 = convert(markdown_msg1)
        markdown_msg2 = convert(markdown_msg2)

    # The translate button is only offered in the English UI.
    translate_btn = gr.Button(visible=(radio == "English"))

    return [
        text_input,
        text_output,
        clear_label[index],
        submit_label[index],
        markdown,
        markdown_msg1,
        markdown_msg2,
        translate_btn,
    ]
def clear_text():
    """Return empty values for the input and output text boxes."""
    return ["", ""]
def translate_text(text_input, text_output):
    """Translate the riddle and its answer with ``translate``.

    Returns:
        A ``(translated_input, translated_output)`` tuple; the input is
        translated first.
    """
    translated_input = translate(f"{text_input}")
    translated_output = translate(f"{text_output}")
    return translated_input, translated_output
# %% | |
# css = """ | |
# #markdown { background-image: url("file/data/DSC_0105.jpg"); | |
# background-size: cover; | |
# } | |
# """ | |
# Build the Gradio UI: language selector, riddle input / answer boxes, example
# riddles and the notice texts, then wire the event handlers and launch the app.
# NOTE(review): the Row/Column nesting below is reconstructed; confirm the
# intended layout.
with gr.Blocks() as demo:
    # Index into the *_label lists; 0 selects the Simplified Chinese strings.
    index = 0
    # Sample riddles offered to the user; each row holds the single input value.
    example_list = [
        ["小家伙穿黄袍,花丛中把房造。飞到西来飞到东,人人夸他爱劳动。(猜一动物)"],
        ["一物生来身穿三百多件衣,每天脱一件,年底剩张皮。(猜一物品)"],
        ["A thousand threads, a million strands. Reaching the water, vanishing all at once. (Hint: natural phenomenon)"],
        ["无底洞(猜一成语)"],
    ]
    # UI language selector; changing it triggers change_language (wired below).
    radio = gr.Radio(
        ["简体中文","繁體中文", "English"],show_label=False,value="简体中文"
    )
    # Page title.
    markdown = gr.Markdown(
"""
# Chinese Lantern Riddles Solver with LLM
## 用大语言模型来猜灯谜
""",elem_id="markdown")
    with gr.Row():
        with gr.Column():
            # Riddle input, pre-filled with the first example riddle.
            text_input = gr.Textbox(label=text_input_label[index],
                value="小家伙穿黄袍,花丛中把房造。飞到西来飞到东,人人夸他爱劳动。(猜一动物)", lines = 2)
            with gr.Row():
                # ClearButton itself only lists text_input; the click handler
                # wired below also blanks text_output.
                clear_btn = gr.ClearButton(value=clear_label[index],components=[text_input])
                submit_btn = gr.Button(value=submit_label[index], variant = "primary")
            text_output = gr.Textbox(label=text_output_label[index])
            # Created hidden; change_language makes it visible for English only.
            translate_btn = gr.Button(value="Translate", variant = "primary", scale=0, visible=False)
            # Example riddles are run through helper_text, with output caching
            # requested via cache_examples=True.
            examples = gr.Examples(
                examples=example_list,
                inputs=text_input,
                outputs=text_output,
                fn=helper_text,
                cache_examples=True,
            )
            # Short bilingual introduction to lantern riddles.
            markdown_msg1 = gr.Markdown(
"""
灯谜是中华文化特色文娱活动,自北宋盛行。每年逢正月十五元宵节,将谜语贴在花灯上,让大家可一起猜谜。
Lantern riddle is a traditional Chinese cultural activity. Being popular since the Song Dynasty (960-1276), it \
is held in the Lantern Festival (15th day of the first lunar month). \
When people are viewing the flower lanterns, they can guess the riddles on the lanterns together.
"""
            )
        with gr.Column():
            # Banner image followed by the bilingual disclaimer and acknowledgement.
            markdown_msg2 = gr.Markdown(
"""

---
# 声明 Disclaimer
本应用输出的文本为机器基于模型生成的结果,不代表任何人观点,请谨慎辨别和参考。请在法律允许的范围内使用。
本应用调用了 [Qwen1.5-0.5B-Chat](https://huggingface.co/Qwen/Qwen1.5-0.5B-Chat) 对话语言大模型,\
使用本应用前请务必阅读和同意遵守其[使用授权许可证](https://huggingface.co/Qwen/Qwen1.5-0.5B-Chat/blob/main/LICENSE)。
本应用仅供非商业用途。
The outputs of this application are machine-generated with a statistical model. \
The outputs do not reflect any opinions of any human subjects. You must identify the outputs in caution. \
It is your responsbility to decide whether to accept the outputs. You must use the applicaiton in obedience to the Law.
This application utilizes [Qwen1.5-0.5B-Chat](https://huggingface.co/Qwen/Qwen1.5-0.5B-Chat) \
Conversational Large Language Model. Before using this application, you must read and accept to follow \
the [LICENSE](https://huggingface.co/Qwen/Qwen1.5-0.5B-Chat/blob/main/LICENSE).
This application is for non-commercial use only.
---
# 感谢 Acknowledgement
本应用调用了 [text2vec-base-chinese](https://huggingface.co/shibing624/text2vec-base-chinese) 生成 text vector embeddings.
该模型是以 [apache-2.0](https://www.apache.org/licenses/LICENSE-2.0) 发行。
This application utilizes [text2vec-base-chinese](https://huggingface.co/shibing624/text2vec-base-chinese) to generate text vector embeddings.
The model is released under [apache-2.0](https://www.apache.org/licenses/LICENSE-2.0)。
""")
    # Event wiring: solve on submit, translate both boxes, clear both boxes,
    # and relabel / convert the UI when the language changes.
    submit_btn.click(fn=helper_text, inputs=[text_input,radio], outputs=text_output)
    translate_btn.click(fn=translate_text, inputs=[text_input,text_output], outputs=[text_input,text_output])
    clear_btn.click(fn=clear_text,outputs=[text_input,text_output])
    radio.change(fn=change_language,inputs=[radio,text_input,text_output,
                                            markdown, markdown_msg1,markdown_msg2,translate_btn],
                 outputs=[text_input,text_output,clear_btn,submit_btn,
                          markdown, markdown_msg1,markdown_msg2,translate_btn])
#demo = gr.Interface(fn=helper_text, inputs=text_input, outputs=text_output,
# flagging_options=["Inappropriate"],allow_flagging="never",
# title="aaa",description="aaa",article="aaa")
#demo.queue(api_open=False)
# Launch with the API page hidden; the script's data/ directory is added to the
# allowed paths (the banner image in markdown_msg2 refers to a file under data/).
demo.launch(show_api=False,allowed_paths=[current_path+"/data/"])
# %% | |