# mathstral_test / app.py
# Author: MarcdeFalco
# Last commit: "Migrate to transformers" (ce78f1b, verified)
# File size: 3.78 kB
import spaces
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
from threading import Thread
import gradio as gr
import torch
import os
# Hub checkpoint served by this demo, and the device its weights live on.
model_name = "mistralai/mathstral-7B-v0.1"
device = "cuda"

# Load the tokenizer and the half-precision weights once at import time, then
# move the model onto `device`; both are shared by every generate() call.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
)
model = model.to(device)

# Fails fast with KeyError when the variable is not set.
# NOTE(review): this is read only after the model has been downloaded, and
# nothing in the visible source uses HF_TOKEN — confirm whether it is needed.
HF_TOKEN = os.environ["HF_TOKEN"]
def format_prompt(message, history):
    """Render the chat history plus a new message as a Mistral ``[INST]`` prompt.

    Each past exchange becomes ``"[INST] <user> [/INST] <reply> "``, and the
    new message is appended as a final ``"[INST] <message> [/INST]"`` so the
    model continues with its answer.

    Args:
        message: The new user message.
        history: Iterable of ``(user_message, bot_response)`` pairs.

    Returns:
        The concatenated prompt string.
    """
    segments = [
        f"[INST] {past_question} [/INST] {past_answer} "
        for past_question, past_answer in history
    ]
    segments.append(f"[INST] {message} [/INST]")
    return "".join(segments)
@spaces.GPU
def generate(prompt, history,
             max_new_tokens=1024,
             repetition_penalty=1.2):
    """Stream the model's reply to ``prompt`` given the previous chat turns.

    Used as the ``gr.ChatInterface`` callback. Generation runs in a background
    thread, and the partial reply is yielded each time the streamer produces
    more text.

    Args:
        prompt: The user's new message.
        history: Previous ``(user_message, bot_response)`` pairs, formatted by
            ``format_prompt``.
        max_new_tokens: Upper bound on generated tokens. It is cast to ``int``
            because UI slider values may arrive as floats.
        repetition_penalty: Passed through to ``model.generate``.

    Yields:
        The reply generated so far, without the prompt and special tokens.

    Returns:
        The complete reply (the generator's ``StopIteration`` value).
    """
    formatted_prompt = format_prompt(prompt, history)
    inputs = tokenizer(formatted_prompt, return_tensors="pt").to(device)
    # skip_prompt: emit only newly generated text. This replaces slicing the
    # echoed prompt off by character count, which breaks whenever decoding
    # does not reproduce the prompt string exactly.
    # skip_special_tokens: keep the trailing EOS ("</s>") out of the reply.
    streamer = TextIteratorStreamer(
        tokenizer, skip_prompt=True, skip_special_tokens=True
    )
    generate_kwargs = dict(
        inputs,
        streamer=streamer,
        max_new_tokens=int(max_new_tokens),
        repetition_penalty=float(repetition_penalty),
    )
    # model.generate blocks until it finishes, so run it in a worker thread
    # and consume the streamer on this one.
    thread = Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()
    text = ""
    for chunk in streamer:
        text += chunk
        yield text
    # The streamer is exhausted once generation ends, so this returns promptly.
    thread.join()
    return text
# Extra controls shown under the chat box. gr.ChatInterface passes their
# values to generate() after (message, history), so the order here must match
# generate's (max_new_tokens, repetition_penalty) parameters.
additional_inputs = [
    gr.Slider(
        label="Max new tokens",
        value=1024,
        # Asking for 0 new tokens produces nothing (recent transformers
        # versions reject max_new_tokens <= 0), so the lowest choice is one step.
        minimum=256,
        maximum=4096,
        step=256,
        interactive=True,
        info="The maximum number of new tokens",
    ),
    gr.Slider(
        label="Repetition penalty",
        value=1.2,
        minimum=1.0,
        maximum=2.0,
        step=0.05,
        interactive=True,
        info="Penalize repeated tokens",
    ),
]
# Custom stylesheet handed to gr.Blocks(css=...). It makes the element with id
# "mkd" a 500px-high, scrollable box with a light grey border.
# NOTE(review): no component in the visible source sets elem_id="mkd", so this
# rule appears to have no effect — confirm before relying on it.
css = (
    "\n"
    "#mkd {\n"
    "height: 500px;\n"
    "overflow: auto;\n"
    "border: 1px solid #ccc;\n"
    "}\n"
)
# Example prompts offered under the chat box: one per non-blank line of
# exercices.md. Reading inside `with` closes the file, and blank lines no
# longer turn into empty examples.
with open("exercices.md", encoding="utf-8") as exercises_file:
    examples = [[line.strip()] for line in exercises_file if line.strip()]

# The theme is set on the outer Blocks, which owns the page; a theme passed
# to the nested ChatInterface is not applied to the rendered page.
with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
    gr.HTML("<h1><center>Mathstral Test</center></h1>")
    gr.HTML("<h3><center>Dans cette démo, vous pouvez poser des questions mathématiques et scientifiques à Mathstral. 🧮</center></h3>")
    gr.ChatInterface(
        generate,
        additional_inputs=additional_inputs,
        cache_examples=False,
        examples=examples,
        # Render both display ($$...$$, \[...\]) and inline (\(...\), $...$)
        # LaTeX in chat messages.
        chatbot=gr.Chatbot(
            latex_delimiters=[
                {"left": "$$", "right": "$$", "display": True},
                {"left": "\\[", "right": "\\]", "display": True},
                {"left": "\\(", "right": "\\)", "display": False},
                {"left": "$", "right": "$", "display": False},
            ]
        ),
    )

# Queue at most 100 pending requests; debug=True keeps launch() blocking and
# prints errors to the console.
demo.queue(max_size=100).launch(debug=True)
# NOTE: this is where the file previously had a leftover copy of its tail, which
# has been removed. The copy began with what looked like a truncated fragment of
# an earlier slider definition. That fragment was an unterminated string
# literal, i.e. a SyntaxError. The copy then repeated the CSS, UI and launch
# code, which would build and launch the demo a second time. Those definitions
# exist once, above.