tensorgirl commited on
Commit
3760948
·
verified ·
1 Parent(s): 3dd590c

Create main.py

Browse files
Files changed (1) hide show
  1. main.py +50 -0
main.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI
2
+ import os
3
+ import json
4
+ import google.generativeai as genai
5
+
6
+ app = FastAPI()
7
+
8
+ GOOGLE_API_KEY = os.environ.get("GOOGLE_API_KEY")
9
+ genai.configure(api_key=GOOGLE_API_KEY)
10
+
11
+ # Set up the model
12
+ generation_config = {
13
+ "temperature": 0.9,
14
+ "top_p": 1,
15
+ "top_k": 1,
16
+ "max_output_tokens": 2048,
17
+ }
18
+
19
+
20
+ model = genai.GenerativeModel(
21
+ model_name="gemini-pro",
22
+ generation_config=generation_config,
23
+ )
24
+
25
+ task_description = " You need to classify each message you receive among the following categories: 'admiration','amusement','anger','annoyance','approval','caring','confusion','curiosity','desire','disappointment','disapproval','disgust','embarrassment','excitement','fear','gratitude','grief','joy','love','nervousness', 'optimism', 'pride', 'realization', 'relief', 'remorse', 'sadness', 'surprise', 'neutral'<div>The output must be in JSON format</div>"
26
+
27
+
28
+ def classify_msg(message):
29
+ prompt_parts = [
30
+ task_description,
31
+ f"Message: {message}",
32
+ "Category: ",
33
+ ]
34
+
35
+ response = model.generate_content(prompt_parts)
36
+
37
+ json_response = json.loads(
38
+ response.text[response.text.find("{") : response.text.rfind("}") + 1]
39
+ )
40
+
41
+ return gr.Label(json_response['category'])
42
+
43
+
44
+ @app.get("/")
45
+ async def root():
46
+ return {"Text Emotion Classification":"Version 1.5 'Text'"}
47
+
48
+ @app.post("/classify/")
49
+ def read_user(text: str):
50
+ return classify_msg(text)