karthik18AI committed
Commit 09777ea · verified · 1 Parent(s): b7a9a38

update mistral

Files changed (1):
  1. app.py +55 -38
app.py CHANGED
@@ -1,41 +1,58 @@
- from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
-
- model_name = "Salesforce/codet5-base"
- model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
- tokenizer = AutoTokenizer.from_pretrained(model_name)
-
- # More complex C# code snippet
- code_snippet = """
- // This class demonstrates a simple calculator program
- public class Calculator {
-     // Adds two integers
-     public int Add(int a, int b) {
-         return a + b;
-     }
-
-     // Subtracts second integer from first
-     public int Subtract(int a, int b) {
-         return a - b;
-     }
-
-     // Multiplies two integers
-     public int Multiply(int a, int b) {
-         return a * b;
-     }
-
-     // Divides first integer by second
-     // Throws DivideByZeroException if b is zero
-     public int Divide(int a, int b) {
-         if (b == 0) {
-             throw new DivideByZeroException("Division by zero is not allowed.");
-         }
-         return a / b;
-     }
  }
- """
-
- inputs = tokenizer(code_snippet, return_tensors="pt")
- outputs = model.generate(**inputs, max_new_tokens=100)
- review = tokenizer.decode(outputs[0], skip_special_tokens=True)
-
- print("Code Review:", review)
+ from huggingface_hub import InferenceClient
+ import gradio as gr
+
+ css = '''
+ .gradio-container{max-width: 1000px !important}
+ h1{text-align:center}
+ footer {
+     visibility: hidden
  }
+ '''
+
+ client = InferenceClient("mistralai/Mistral-7B-Instruct-v0.3")
+
+ def format_prompt(message, history, system_prompt=None):
+     # Flatten the chat history into Mistral's [INST] ... [/INST] format
+     prompt = "<s>"
+     for user_prompt, bot_response in history:
+         prompt += f"[INST] {user_prompt} [/INST]"
+         prompt += f" {bot_response}</s> "
+     if system_prompt:
+         prompt += f"[SYS] {system_prompt} [/SYS]"
+     prompt += f"[INST] {message} [/INST]"
+     return prompt
+
+ # Generate a streamed completion from the Inference API
+ def generate(
+     prompt, history, system_prompt=None, temperature=0.2, max_new_tokens=1024, top_p=0.95, repetition_penalty=1.0,
+ ):
+     temperature = float(temperature)
+     if temperature < 1e-2:
+         temperature = 1e-2
+     top_p = float(top_p)
+
+     generate_kwargs = dict(
+         temperature=temperature,
+         max_new_tokens=max_new_tokens,
+         top_p=top_p,
+         repetition_penalty=repetition_penalty,
+         do_sample=True,
+         seed=42,
+     )
+
+     formatted_prompt = format_prompt(prompt, history, system_prompt)
+
+     stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
+     output = ""
+
+     # Accumulate tokens and yield the running text so the reply streams in the UI
+     for response in stream:
+         output += response.token.text
+         yield output
+     return output
+
+ demo = gr.ChatInterface(
+     fn=generate,
+     css=css,
+     title="",
+     theme="bethecloud/storj_theme"
+ )
+
+ demo.queue().launch(show_api=False)
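
A note on the new prompt template, since the Mistral instruct format is easy to get wrong: format_prompt wraps each user turn in [INST] ... [/INST] and closes each bot reply with </s>. A minimal sketch of the string it builds, assuming app.py's format_prompt is in scope and using hypothetical inputs (the [SYS] tags mirror the committed code; Mistral's documented template instead folds system text into the first [INST] block):

history = [("What is 2 + 2?", "4.")]  # hypothetical prior turn
print(format_prompt("And 3 + 3?", history, system_prompt="Be terse."))
# <s>[INST] What is 2 + 2? [/INST] 4.</s> [SYS] Be terse. [/SYS][INST] And 3 + 3? [/INST]

Because generate yields the running output on every streamed token, gr.ChatInterface renders the reply incrementally; it calls fn with (message, history), so the remaining keyword arguments keep their defaults unless wired up as additional inputs.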