arittrabag commited on
Commit
766f064
·
verified ·
1 Parent(s): dc3cd8a

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +149 -0
app.py ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from fastapi import FastAPI, HTTPException
3
+ from fastapi.middleware.cors import CORSMiddleware
4
+ from pydantic import BaseModel
5
+ from typing import List
6
+ from enhanced_prompt_builder import EnhancedPromptBuilder
7
+ from feedback_analyzer import FeedbackAnalyzer
8
+ from google import generativeai as genai
9
+ from datetime import datetime
10
+ import json
11
+ from dotenv import load_dotenv
12
+
13
+ load_dotenv()
14
+
15
+ # Set your Gemini API key here (or load via env var in production)
16
+ genai.configure(api_key=os.getenv("GEMINI_API_KEY"))
17
+ model = genai.GenerativeModel("gemini-2.5-flash")
18
+
19
+ app = FastAPI()
20
+
21
+ # Add CORS middleware
22
+ app.add_middleware(
23
+ CORSMiddleware,
24
+ allow_origins=["*"], # Allows all origins
25
+ allow_credentials=True,
26
+ allow_methods=["*"], # Allows all methods
27
+ allow_headers=["*"], # Allows all headers
28
+ )
29
+
30
+ # Initialize enhanced components
31
+ enhanced_builder = EnhancedPromptBuilder()
32
+ feedback_analyzer = FeedbackAnalyzer()
33
+
34
+ class AdRequest(BaseModel):
35
+ ad_text: str
36
+ tone: str
37
+ platforms: List[str]
38
+
39
+ class Feedback(BaseModel):
40
+ ad_text: str
41
+ tone: str
42
+ platforms: List[str]
43
+ rewritten_output: str
44
+ rating: int # 1 to 5
45
+
46
+ @app.post("/run-enhanced-agent")
47
+ def run_enhanced_agent(request: AdRequest):
48
+ """Run the agent with enhanced RAG, KG traversal, and adaptive learning"""
49
+ try:
50
+ # Use enhanced prompt builder
51
+ prompt = enhanced_builder.build_adaptive_prompt(
52
+ request.ad_text,
53
+ request.tone,
54
+ request.platforms
55
+ )
56
+
57
+ # Generate response
58
+ response = model.generate_content(prompt)
59
+
60
+ # Get improvement suggestions
61
+ suggestions = enhanced_builder.get_improvement_suggestions()
62
+
63
+ return {
64
+ "rewritten_ads": response.text,
65
+ "metadata": {
66
+ "used_enhanced_features": True,
67
+ "improvement_suggestions": suggestions[:3] # Top 3 suggestions
68
+ }
69
+ }
70
+ except Exception as e:
71
+ raise HTTPException(status_code=500, detail=str(e))
72
+
73
+ @app.post("/feedback")
74
+ def submit_feedback(feedback: Feedback):
75
+ entry = {
76
+ "timestamp": datetime.now().isoformat(),
77
+ "ad_text": feedback.ad_text,
78
+ "tone": feedback.tone,
79
+ "platforms": feedback.platforms,
80
+ "rewritten_output": feedback.rewritten_output,
81
+ "rating": feedback.rating
82
+ }
83
+
84
+ try:
85
+ with open("feedback_store.json", "r+", encoding="utf-8") as f:
86
+ data = json.load(f)
87
+ data.append(entry)
88
+ f.seek(0)
89
+ json.dump(data, f, indent=2)
90
+ return {"message": "Feedback submitted successfully"}
91
+ except Exception as e:
92
+ raise HTTPException(status_code=500, detail=f"Error storing feedback: {str(e)}")
93
+
94
+ @app.get("/insights")
95
+ def get_insights():
96
+ """Get insights from feedback analysis"""
97
+ try:
98
+ analysis = feedback_analyzer.analyze_patterns()
99
+ trends = feedback_analyzer.get_time_based_trends()
100
+ weights = feedback_analyzer.get_adaptive_weights()
101
+
102
+ return {
103
+ "analysis_summary": {
104
+ "total_feedback": analysis.get("total_feedback", 0),
105
+ "average_rating": round(analysis.get("average_rating", 0), 2),
106
+ "recommendations": analysis.get("recommendations", [])[:5]
107
+ },
108
+ "performance_by_tone": analysis.get("tone_stats", {}),
109
+ "performance_by_platform": analysis.get("platform_stats", {}),
110
+ "winning_combinations": analysis.get("high_performing_patterns", []),
111
+ "needs_improvement": analysis.get("low_performing_patterns", []),
112
+ "adaptive_weights": weights,
113
+ "recent_trends": trends
114
+ }
115
+ except Exception as e:
116
+ raise HTTPException(status_code=500, detail=str(e))
117
+
118
+ @app.get("/graph-insights/{tone}/{platform}")
119
+ def get_graph_insights(tone: str, platform: str):
120
+ """Get knowledge graph insights for a specific tone-platform combination"""
121
+ try:
122
+ from enhanced_knowledge_graph import EnhancedKnowledgeGraph
123
+ kg = EnhancedKnowledgeGraph()
124
+
125
+ recommendations = kg.get_recommendations(tone, platform)
126
+ relationship = kg.explain_relationship(tone, platform)
127
+
128
+ # Find related nodes
129
+ tone_related = kg.traverse_bfs(tone, max_depth=2)
130
+ platform_related = kg.traverse_bfs(platform, max_depth=2)
131
+
132
+ return {
133
+ "tone_platform_analysis": {
134
+ "tone": tone,
135
+ "platform": platform,
136
+ "compatibility_score": recommendations["compatibility_score"],
137
+ "relationship_explanation": relationship,
138
+ "suggestions": recommendations["suggested_elements"],
139
+ "warnings": recommendations["warnings"],
140
+ "recommended_creative_types": recommendations["creative_types"]
141
+ },
142
+ "graph_connections": {
143
+ "tone_connections": list(tone_related.keys()),
144
+ "platform_connections": list(platform_related.keys())
145
+ }
146
+ }
147
+ except Exception as e:
148
+ raise HTTPException(status_code=500, detail=str(e))
149
+