hardik90 committed
Commit 4f6a1fe · verified · 1 Parent(s): a659f37

Create app.py

Files changed (1): app.py (+49, -0)
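This commit adds app.py, a single-file Streamlit front end that chats with mistralai/Mistral-7B-Instruct-v0.1 through the Hugging Face Inference API (huggingface_hub.InferenceClient).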
app.py ADDED
import streamlit as st
from huggingface_hub import InferenceClient

# Client for the hosted Inference API endpoint of Mistral-7B-Instruct.
client = InferenceClient("mistralai/Mistral-7B-Instruct-v0.1")

def format_prompt(message, history):
    # Build a Mistral-Instruct prompt: each past turn is wrapped in
    # [INST] ... [/INST], followed by the bot reply and an </s> end token.
    prompt = "<s>"
    for user_prompt, bot_response in history:
        prompt += f"[INST] {user_prompt} [/INST]"
        prompt += f" {bot_response}</s> "
    prompt += f"[INST] {message} [/INST]"
    return prompt

def generate(prompt, history, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0):
    # The API rejects a temperature of 0, so clamp to a small positive value.
    temperature = float(temperature)
    if temperature < 1e-2:
        temperature = 1e-2
    top_p = float(top_p)

    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        seed=42,
    )

    formatted_prompt = format_prompt(prompt, history)
    # text_generation returns the generated text as a plain string by default,
    # so return it directly (there is no .choices attribute to index into).
    response = client.text_generation(formatted_prompt, **generate_kwargs)
    return response

def main():
    st.title("Mistral 7B")
    prompt = st.text_input("User Input:", "")

    # Keep the conversation history across Streamlit reruns via session state.
    if "history" not in st.session_state:
        st.session_state.history = []

    temperature = st.slider("Temperature", 0.0, 1.0, 0.9, step=0.05)
    max_new_tokens = st.slider("Max new tokens", 0, 1048, 256, step=64)
    top_p = st.slider("Top-p (nucleus sampling)", 0.0, 1.0, 0.90, step=0.05)
    repetition_penalty = st.slider("Repetition penalty", 1.0, 2.0, 1.2, step=0.05)

    if st.button("Generate"):
        output = generate(prompt, st.session_state.history,
                          temperature, max_new_tokens, top_p, repetition_penalty)
        st.session_state.history.append((prompt, output))
        st.text("Bot Output:")
        st.write(output)

if __name__ == "__main__":
    main()
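
For reference, a quick sanity check of what format_prompt produces for a single past turn; the history tuple below is a made-up example, not part of the commit:

# Hypothetical one-turn history, for illustration only.
history = [("Hello", "Hi there! How can I help?")]
print(format_prompt("What does top_p do?", history))
# -> <s>[INST] Hello [/INST] Hi there! How can I help?</s> [INST] What does top_p do? [/INST]

Locally, the app runs with "streamlit run app.py" after installing the streamlit and huggingface_hub packages.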