tensorgirl commited on
Commit
856414b
·
verified ·
1 Parent(s): 2a5ffea

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -12
app.py CHANGED
@@ -1,14 +1,42 @@
1
- from transformers import AutoTokenizer, GemmaForSequenceClassification
 
2
  import gradio as gr
3
- import os
4
- model = GemmaForSequenceClassification.from_pretrained("google/gemma-2b")
5
- tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
6
- def greet(name):
7
- x = tokenizer(name,return_tensors='pt')
8
- with torch.no_grad():
9
- preds = model(**inputs)
10
-
11
- return preds
12
-
13
- iface = gr.Interface(fn=greet, inputs="text", outputs="text")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  iface.launch()
 
1
+ import os
2
+ import json
3
  import gradio as gr
4
+ import google.generativeai as genai
5
+
6
+ GOOGLE_API_KEY = os.environ.get("GOOGLE_API_KEY")
7
+ genai.configure(api_key=GOOGLE_API_KEY)
8
+
9
+ # Set up the model
10
+ generation_config = {
11
+ "temperature": 0.9,
12
+ "top_p": 1,
13
+ "top_k": 1,
14
+ "max_output_tokens": 2048,
15
+ }
16
+
17
+
18
+ model = genai.GenerativeModel(
19
+ model_name="gemini-pro",
20
+ generation_config=generation_config,
21
+ )
22
+
23
+ task_description = " You need to classify each message you receive among the following categories: 'admiration','amusement','anger','annoyance','approval','caring','confusion','curiosity','desire','disappointment','disapproval','disgust','embarrassment','excitement','fear','gratitude','grief','joy','love','nervousness', 'optimism', 'pride', 'realization', 'relief', 'remorse', 'sadness', 'surprise', 'neutral'<div>Let the output be a softmax function output giving the probability of message belonging to each category.</div><div>The sum of the probabilities should be 1</div><div>The output must be in JSON format</div>"
24
+
25
+
26
+ def classify_msg(message):
27
+ prompt_parts = [
28
+ task_description,
29
+ f"Message: {message}",
30
+ "Category: ",
31
+ ]
32
+
33
+ response = model.generate_content(prompt_parts)
34
+
35
+ json_response = json.loads(
36
+ response.text[response.text.find("{") : response.text.rfind("}") + 1]
37
+ )
38
+
39
+ return gr.Label(json_response)
40
+
41
+ iface = gr.Interface(fn=classify_msg, inputs="text", outputs="text")
42
  iface.launch()