import spaces
import os

from huggingface_hub import Repository, login

init_feedback = False

try:
    login(token=os.environ['HUB_TOKEN'])

    # Clone the feedback dataset repo locally so votes can be pushed back to the Hub.
    repo = Repository(
        local_dir="backend_fn",
        repo_type="dataset",
        clone_from=os.environ['DATASET'],
        token=True,
        git_email='[email protected]'
    )
    repo.git_pull()

    init_feedback = True
except Exception:
    # Feedback logging is optional; run without it if the token or dataset is unavailable.
    pass

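# NOTE (hedged sketch): `Repository` is deprecated in recent `huggingface_hub`
# releases in favour of the HTTP-based API. An equivalent pull of the dataset,
# assuming the same DATASET env var, would look roughly like:
#
#     from huggingface_hub import snapshot_download
#     snapshot_download(
#         repo_id=os.environ['DATASET'],
#         repo_type="dataset",
#         local_dir="backend_fn",
#     )
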
import json
import uuid
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from threading import Thread

if init_feedback:
    from backend_fn.feedback import feedback

from gradio_modal import Modal

"""
For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
"""
model_name = "Merdeka-LLM/merdeka-llm-hr-3b-128k-instruct"

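# Optional sanity check of the checkpoint before wiring up the UI (hedged
# sketch, not part of the app; the prompt string is made up for illustration):
#
#     from transformers import pipeline
#     pipe = pipeline("text-generation", model=model_name)
#     print(pipe("How many days of annual leave does Malaysian law require?",
#                max_new_tokens=64))
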
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype="auto",
    device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained(model_name)

streamer = TextIteratorStreamer(tokenizer, timeout=300, skip_prompt=True, skip_special_tokens=True)

# Module-level state handed from the like handler to the feedback modal.
histories = []
action = None
feedback_index = None

session_id = str(uuid.uuid1())

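# NOTE: uuid1() derives the session id from the host MAC address and a
# timestamp, so ids are predictable. If that matters, a fully random id is one
# line away (hedged suggestion):
#
#     session_id = str(uuid.uuid4())
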
@spaces.GPU
def respond(
    message,
    history: list[tuple[str, str]],
    max_tokens=4096,
    temperature=0.01,
    top_p=0.95,
):
    messages = [
        {"role": "system", "content": "You are a professional Human Resource advisor who is familiar with HR-related Malaysian law."}
    ]

    # Rebuild the full conversation from Gradio's (user, assistant) history pairs.
    for val in history:
        if val[0]:
            messages.append({"role": "user", "content": val[0]})
        if val[1]:
            messages.append({"role": "assistant", "content": val[1]})

    messages.append({"role": "user", "content": message})

    response = ""

    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

    generate_kwargs = dict(
        model_inputs,
        max_new_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        streamer=streamer
    )
    # Run generation on a background thread so tokens can be yielded as they stream in.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()
    for new_token in streamer:
        if new_token != '<':
            response += new_token
            yield response

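# NOTE: `streamer` above is a single module-level instance, so two chats
# generating at the same time would interleave their tokens. A safer
# per-request variant (hedged sketch) constructs it inside `respond` and passes
# that local instance via `generate_kwargs`:
#
#     streamer = TextIteratorStreamer(tokenizer, timeout=300, skip_prompt=True,
#                                     skip_special_tokens=True)
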
""" |
|
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface |
|
""" |
|
|
|
def submit_feedback(value): |
|
feedback(session_id, json.dumps(histories), value, action, feedback_index) |
|
|
|
|
|
with gr.Blocks() as demo:
    def vote(history, data: gr.LikeData):
        # Record which message was liked/disliked so the modal submit can report it.
        global histories
        global action
        global feedback_index
        histories = history
        action = data.liked
        feedback_index = data.index[0]

    with Modal(visible=False) as modal:
        textb = gr.Textbox(
            label='Actual response',
            info='Leave blank if the answer is good enough'
        )

        submit_btn = gr.Button(
            'Submit'
        )

        submit_btn.click(submit_feedback, textb)
        submit_btn.click(lambda: Modal(visible=False), None, modal)
        submit_btn.click(lambda: gr.update(value=''), [], [textb])

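    # NOTE: the three `.click(...)` bindings above fire as independent events,
    # so their relative order is not guaranteed. If the textbox should only be
    # cleared after the feedback is saved, the calls could be chained instead
    # (hedged sketch):
    #
    #     submit_btn.click(submit_feedback, textb).then(
    #         lambda: Modal(visible=False), None, modal
    #     ).then(lambda: gr.update(value=''), None, textb)
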
    ci = gr.ChatInterface(
        respond,
        description='Due to an unknown bug in Gradio, we are unable to expand the conversation section to full height.'
    )

    ci.chatbot.show_copy_button = True

    if init_feedback:
        # Open the feedback modal right after a like/dislike is recorded.
        ci.chatbot.like(vote, ci.chatbot, None).then(
            lambda: Modal(visible=True), None, modal
        )

if __name__ == "__main__":
    demo.launch()