acecalisto3 committed on
Commit
0d4009b
·
verified ·
1 Parent(s): 3ff0de1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +44 -0
app.py CHANGED
@@ -3,6 +3,50 @@ from transformers import pipeline
3
  import logging
4
  import torch
5
  import numpy as np
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
  # Logging Setup
8
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
 
3
  import logging
4
  import torch
5
  import numpy as np
6
+ from huggingface_hub import InferenceClient
7
+ import gradio as gr
8
+
9
# Module-level Inference API client bound to the Mixtral-8x7B-Instruct model; generate() sends its prompts through it.
+ client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")
10
+
11
def format_prompt(message, history):
    """Build the plain-text prompt that is sent to the model.

    Args:
        message: The new user message, appended after the history.
        history: Iterable of ``(user_prompt, bot_response)`` pairs from
            earlier turns; may be empty.

    Returns:
        The prompt string: an ``[Instruction]:`` header, one
        ``user_prompt: bot_response`` line per past turn, then the new
        message terminated by ``<|endoftext|>`` and a newline.
    """
    # The header was previously misspelled "[Instuction]"; the corrected tag
    # matches the "[Instruction]" spelling used when normalising system prompts.
    parts = ["[Instruction]:\n"]
    parts.extend(
        f"\n{user_prompt}: {bot_response}"
        for user_prompt, bot_response in history
    )
    parts.append(f"\n{message}<|endoftext|>\n")
    return "".join(parts)
18
+
19
def generate(input_text, history, system_prompt, temperature, max_length, top_p, repetition_penalty):
    """Generate one model reply for ``input_text`` given the chat history.

    Args:
        input_text: The user's message. The literal
            ``'[continue conversation]'`` drops the most recent history turn
            before the prompt is built.
        history: Sequence of ``(user_prompt, bot_response)`` pairs, or
            ``None`` for an empty conversation. Turns whose user part is the
            empty string are discarded.
        system_prompt: Optional instruction text placed before the prompt.
            Any ``[INST]`` tag in it is rewritten to ``[Instruction]``.
        temperature: Sampling temperature (converted with ``float``).
        max_length: Maximum number of newly generated tokens (converted with
            ``int``).
        top_p: Nucleus-sampling threshold (converted with ``float``).
        repetition_penalty: Repetition penalty (converted with ``float``).

    Returns:
        A dict with ``"history"`` (the filtered history plus the new
        ``(input_text, result)`` turn) and ``"result"`` (the stripped reply).
    """
    input_text = input_text.strip()

    history = list(history or [])
    if input_text == '[continue conversation]':
        history = history[:-1]
    history = [turn for turn in history if turn[0] != '']

    # The system prompt used to be normalised and then dropped; it is now
    # prepended so that it actually reaches the model.
    system_prompt = (system_prompt or "").replace("[INST]", "[Instruction]").strip()
    prompt = format_prompt(input_text, history)
    if system_prompt:
        prompt = f"{system_prompt}\n{prompt}"

    # InferenceClient exposes text_generation(), not generate(). It takes
    # max_new_tokens instead of max_length, has no num_return_sequences, and
    # returns the generated text as a plain string by default.
    # NOTE(review): some serving backends reject top_p == 1.0 — confirm
    # against the endpoint if the slider maximum is used.
    result = client.text_generation(
        prompt,
        max_new_tokens=int(max_length),
        # Sampling needs a strictly positive temperature; clamp a 0.0 input.
        temperature=max(float(temperature), 1e-2),
        top_p=float(top_p),
        repetition_penalty=float(repetition_penalty),
        do_sample=True,
    ).strip()

    return {"history": history + [(input_text, result)], "result": result}
38
+
39
def _generate_for_interface(input_text, system_prompt, temperature, max_length, top_p, repetition_penalty):
    """Adapt the six UI inputs to generate() and return only the reply text.

    The interface has no history input, so every call starts from an empty
    conversation. generate() returns a dict; the single markdown output needs
    just the reply string.
    """
    outcome = generate(
        input_text,
        [],
        system_prompt,
        temperature,
        max_length,
        top_p,
        repetition_penalty,
    )
    return outcome["result"]


# Gradio UI: a message box, a system-prompt box and four sampling sliders.
# The slider ranges come from the original (min, max, step) triples
# (presumably that order — TODO confirm); the default values are taken from
# the bundled example row.
iface = gr.Interface(
    fn=_generate_for_interface,
    inputs=[
        gr.Textbox(label="input_text"),
        gr.Textbox(label="system_prompt", lines=3),
        gr.Slider(minimum=0.0, maximum=1.0, step=0.1, value=0.5, label="temperature"),
        gr.Slider(minimum=20, maximum=256, step=1, value=50, label="max_length"),
        gr.Slider(minimum=0.1, maximum=1.0, step=0.1, value=0.5, label="top_p"),
        gr.Slider(minimum=1.0, maximum=2.0, step=0.1, value=1.1, label="repetition_penalty"),
    ],
    outputs="markdown",
    # One example row, with values in the same order as `inputs`.
    examples=[["How old are you?", "", 0.5, 50, 0.5, 1.1]],
    allow_flagging="never",
)
50
 
51
  # Logging Setup
52
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')