ghengx committed on
Commit
2fc39d8
1 Parent(s): bbf672d
Files changed (2) hide show
  1. app.py +164 -0
  2. requirements.txt +65 -0
app.py ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import spaces
2
+ import os
3
+
4
+ from huggingface_hub import Repository
5
+ from huggingface_hub import login
6
+
7
+ init_feedback = False
8
+
9
+ try:
10
+ login(token = os.environ['HUB_TOKEN'])
11
+
12
+ repo = Repository(
13
+ local_dir="backend_fn",
14
+ repo_type="dataset",
15
+ clone_from=os.environ['DATASET'],
16
+ token=True,
17
+ git_email='[email protected]'
18
+ )
19
+ repo.git_pull()
20
+
21
+ init_feedback = True
22
+ except:
23
+ pass
24
+
25
+ import json
26
+ import uuid
27
+ import gradio as gr
28
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
29
+ from threading import Thread
30
+
31
+ if init_feedback:
32
+ from backend_fn.feedback import feedback
33
+
34
+ from gradio_modal import Modal
35
+
36
+ """
37
+ For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
38
+ """
39
+ model_name = "Merdeka-LLM/merdeka-llm-hr-3b-128k-instruct"
40
+
41
+ model = AutoModelForCausalLM.from_pretrained(
42
+ model_name,
43
+ torch_dtype="auto",
44
+ device_map="auto"
45
+ )
46
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
47
+
48
+ streamer = TextIteratorStreamer(tokenizer, timeout=300, skip_prompt=True, skip_special_tokens=True)
49
+
50
+ histories = []
51
+ action = None
52
+ feedback_index = None
53
+
54
+ session_id = uuid.uuid1().__str__()
55
+
56
+ @spaces.GPU
57
+ def respond(
58
+ message,
59
+ history: list[tuple[str, str]],
60
+ # system_message,
61
+ max_tokens = 4096,
62
+ temperature = 0.01,
63
+ top_p = 0.95,
64
+ ):
65
+ messages = [
66
+ {"role": "system", "content": "You are a professional lawyer who is familiar with Malaysia Law."}
67
+ ]
68
+
69
+ for val in history:
70
+ if val[0]:
71
+ messages.append({"role": "user", "content": val[0]})
72
+ if val[1]:
73
+ messages.append({"role": "assistant", "content": val[1]})
74
+
75
+ messages.append({"role": "user", "content": message})
76
+
77
+ response = ""
78
+
79
+ text = tokenizer.apply_chat_template(
80
+ messages,
81
+ tokenize=False,
82
+ add_generation_prompt=True,
83
+ )
84
+ model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
85
+
86
+ generate_kwargs = dict(
87
+ model_inputs,
88
+ max_new_tokens=max_tokens,
89
+ temperature=temperature,
90
+ top_p=top_p,
91
+ streamer=streamer
92
+ )
93
+ t = Thread(target=model.generate, kwargs=generate_kwargs)
94
+ t.start()
95
+ for new_token in streamer:
96
+ if new_token != '<':
97
+ response += new_token
98
+ yield response
99
+
100
+ """
101
+ For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
102
+ """
103
+
104
+ def submit_feedback(value):
105
+ feedback(session_id, json.dumps(histories), value, action, feedback_index)
106
+
107
+
108
+ with gr.Blocks() as demo:
109
+ def vote(history,data: gr.LikeData):
110
+ global histories
111
+ global action
112
+ global feedback_index
113
+ histories = history
114
+ action = data.liked
115
+ feedback_index = data.index[0]
116
+
117
+ with Modal(visible=False) as modal:
118
+ textb = gr.Textbox(
119
+ label='Actual response',
120
+ info='Leave blank if the answer is good enough'
121
+ )
122
+
123
+ submit_btn = gr.Button(
124
+ 'Submit'
125
+ )
126
+
127
+ submit_btn.click(submit_feedback,textb)
128
+ submit_btn.click(lambda: Modal(visible=False), None, modal)
129
+ submit_btn.click(lambda x: gr.update(value=''), [],[textb])
130
+
131
+
132
+ ci = gr.ChatInterface(
133
+ respond,
134
+ description='Due to an unknown bug in Gradio, we are unable to expand the conversation section to full height.'
135
+ # fill_height=True
136
+ # additional_inputs=[
137
+ # # gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
138
+ # gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
139
+ # gr.Slider(minimum=0.1, maximum=4.0, value=0.1, step=0.1, label="Temperature"),
140
+ # gr.Slider(
141
+ # minimum=0.1,
142
+ # maximum=1.0,
143
+ # value=0.95,
144
+ # step=0.05,
145
+ # label="Top-p (nucleus sampling)",
146
+ # ),
147
+ # ],
148
+ )
149
+
150
+
151
+ ci.chatbot.show_copy_button=True
152
+ # ci.chatbot.value=[(None,"Hello! I'm here to assist you with understanding the laws and acts of Malaysia.")]
153
+ # ci.chatbot.height=500
154
+
155
+ if init_feedback:
156
+ ci.chatbot.like(vote, ci.chatbot, None).then(
157
+ lambda: Modal(visible=True), None, modal
158
+ )
159
+
160
+ if __name__ == "__main__":
161
+ demo.launch(
162
+
163
+ )
164
+
requirements.txt ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ accelerate==1.0.1
2
+ aiofiles==23.2.1
3
+ annotated-types==0.7.0
4
+ anyio==4.6.2.post1
5
+ certifi==2024.8.30
6
+ charset-normalizer==3.4.0
7
+ click==8.1.7
8
+ fastapi==0.115.4
9
+ ffmpy==0.4.0
10
+ filelock==3.16.1
11
+ fsspec==2024.10.0
12
+ gradio==5.4.0
13
+ gradio_client==1.4.2
14
+ gradio_modal==0.0.4
15
+ h11==0.14.0
16
+ httpcore==1.0.6
17
+ httpx==0.27.2
18
+ huggingface-hub==0.26.2
19
+ idna==3.10
20
+ Jinja2==3.1.4
21
+ markdown-it-py==3.0.0
22
+ MarkupSafe==2.1.5
23
+ mdurl==0.1.2
24
+ mpmath==1.3.0
25
+ networkx==3.4.2
26
+ numpy==1.26.4
27
+ orjson==3.10.10
28
+ packaging==24.1
29
+ pandas==2.2.3
30
+ pillow==11.0.0
31
+ psutil==5.9.8
32
+ pydantic==2.9.2
33
+ pydantic_core==2.23.4
34
+ pydub==0.25.1
35
+ Pygments==2.18.0
36
+ PyMySQL==1.1.1
37
+ python-dateutil==2.9.0.post0
38
+ python-multipart==0.0.12
39
+ pytz==2024.2
40
+ PyYAML==6.0.2
41
+ regex==2024.9.11
42
+ requests==2.32.3
43
+ rich==13.9.3
44
+ ruff==0.7.1
45
+ safehttpx==0.1.1
46
+ safetensors==0.4.5
47
+ semantic-version==2.10.0
48
+ setuptools==75.3.0
49
+ shellingham==1.5.4
50
+ six==1.16.0
51
+ sniffio==1.3.1
52
+ spaces==0.30.4
53
+ starlette==0.41.2
54
+ sympy==1.13.1
55
+ tokenizers==0.20.1
56
+ tomlkit==0.12.0
57
+ torch==2.2.0
58
+ tqdm==4.66.6
59
+ transformers==4.46.1
60
+ typer==0.12.5
61
+ typing_extensions==4.12.2
62
+ tzdata==2024.2
63
+ urllib3==2.2.3
64
+ uvicorn==0.32.0
65
+ websockets==12.0