Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -3,6 +3,50 @@ from transformers import pipeline
|
|
3 |
import logging
|
4 |
import torch
|
5 |
import numpy as np
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
|
7 |
# Logging Setup
|
8 |
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
|
|
3 |
import logging
|
4 |
import torch
|
5 |
import numpy as np
|
6 |
+
from huggingface_hub import InferenceClient
|
7 |
+
import gradio as gr
|
8 |
+
|
9 |
+
# Hosted text-generation endpoint for the Mixtral-8x7B instruct model,
# served via the Hugging Face Inference API.
client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")
|
10 |
+
|
11 |
+
def format_prompt(message, history):
    """Flatten the chat history plus the new message into one prompt string.

    Args:
        message: The new user message for this turn.
        history: List of ``(user_prompt, bot_response)`` pairs from prior turns.

    Returns:
        A single text prompt ending with the new message and an
        ``<|endoftext|>`` marker.
    """
    # Fixed typo: was "[Instuction]" — the rest of the file uses the
    # "[Instruction]" spelling (see the [INST] -> [Instruction] replace
    # in generate()).
    prompt = "[Instruction]:\n"
    for user_prompt, bot_response in history:
        prompt += f"\n{user_prompt}: {bot_response}"

    prompt += f"\n{message}<|endoftext|>\n"
    return prompt
|
18 |
+
|
19 |
+
def generate(input_text, history, system_prompt, temperature, max_length, top_p, repetition_penalty):
    """Run one chat turn against the hosted model.

    Args:
        input_text: The user's message; the sentinel '[continue conversation]'
            re-generates the previous turn instead of adding a new one.
        history: List of ``(user_prompt, bot_response)`` pairs so far.
        system_prompt: Free-form instruction text; ``[INST]`` markers are
            normalized to ``[Instruction]``.
        temperature, max_length, top_p, repetition_penalty: Sampling controls,
            coerced to float/int before the API call.

    Returns:
        Dict with ``history`` (updated pair list) and ``result`` (raw model text).
    """
    input_text = input_text.strip()
    # '[continue conversation]' means: redo the last turn, so drop it first.
    history = history[:-1] if input_text == '[continue conversation]' else history
    # Discard turns whose user side is empty.
    history = [turn for turn in history if turn[0] != '']

    # Bug fix: system_prompt was computed here but never included in the
    # prompt sent to the model — prepend it so it actually takes effect.
    system_prompt = system_prompt.replace("[INST]", "[Instruction]").strip() + "\n"
    prompt = system_prompt + format_prompt(input_text, history)

    # Bug fix: InferenceClient has no `.generate(...)` returning
    # pipeline-style [{'generated_text': ...}] records; its text-generation
    # entry point is `text_generation`, which returns the generated string.
    result = client.text_generation(
        prompt,
        max_new_tokens=int(max_length),
        temperature=float(temperature),
        top_p=float(top_p),
        repetition_penalty=float(repetition_penalty),
        do_sample=True,
    ).strip()

    return {"history": history + [(input_text, result)], "result": result}
|
38 |
+
|
39 |
+
# Gradio UI. The original used `gr.Inputs(...)` (not a real Gradio API) and an
# `examples={... ["input_text": ...]}` literal with mismatched brackets — a
# syntax error. Rebuilt with one component per `generate` parameter, in the
# same positional order, keeping the original slider ranges/steps and the
# original example values.
iface = gr.Interface(
    fn=generate,
    inputs=[
        gr.Textbox(label="input_text"),
        gr.State([]),  # conversation history: list of (user, bot) pairs
        gr.Textbox(label="system_prompt", lines=3),
        gr.Slider(0.0, 1.0, step=0.1, value=0.5, label="temperature"),
        gr.Slider(20, 256, step=1, value=50, label="max_length"),
        gr.Slider(0.1, 1.0, step=0.1, value=0.5, label="top_p"),
        gr.Slider(1.0, 2.0, step=0.1, value=1.1, label="repetition_penalty"),
    ],
    # generate() returns a dict, so "markdown" output did not match; render
    # the dict as JSON instead.
    outputs="json",
    examples=[["How old are you?", [], "", 0.5, 50, 0.5, 1.1]],
    allow_flagging="never",
)
|
50 |
|
51 |
# Logging Setup
|
52 |
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|