arittrabag committed on
Commit
2b88d9f
·
verified ·
1 Parent(s): 91b91a6

Upload 3 files

Browse files
Files changed (2) hide show
  1. enhanced_knowledge_graph.py +253 -0
  2. enhanced_retriever.py +128 -0
enhanced_knowledge_graph.py ADDED
@@ -0,0 +1,253 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Set, Tuple, Optional
2
+ from collections import defaultdict, deque
3
+
4
class EnhancedKnowledgeGraph:
    """Directed, weighted knowledge graph of tones, platforms and creative types.

    Attributes:
        nodes: Maps a node name to ``{"type": ..., "properties": {...}}``.
        edges: Maps a source node name to the list of its outgoing edges.
            Each edge is a dict with the keys ``"to"``, ``"relationship"``
            and ``"weight"``.
    """

    def __init__(self):
        """Create the built-in nodes and wire up their relationships."""
        # Node properties
        self.nodes = {
            # Tones
            "fun": {
                "type": "tone",
                "properties": {
                    "formality": 0.2,
                    "energy": 0.9,
                    "creativity": 0.8,
                },
            },
            "professional": {
                "type": "tone",
                "properties": {
                    "formality": 0.9,
                    "energy": 0.5,
                    "creativity": 0.3,
                },
            },
            "semi-fun": {
                "type": "tone",
                "properties": {
                    "formality": 0.5,
                    "energy": 0.7,
                    "creativity": 0.6,
                },
            },

            # Platforms
            "Meta": {
                "type": "platform",
                "properties": {
                    "char_limit": 2200,
                    "emoji_friendly": True,
                    "hashtag_friendly": True,
                    "visual_emphasis": 0.9,
                },
            },
            "Google": {
                "type": "platform",
                "properties": {
                    "char_limit": 90,
                    "emoji_friendly": False,
                    "hashtag_friendly": False,
                    "visual_emphasis": 0.2,
                },
            },
            "LinkedIn": {
                "type": "platform",
                "properties": {
                    "char_limit": 3000,
                    "emoji_friendly": False,
                    "hashtag_friendly": True,
                    "visual_emphasis": 0.4,
                },
            },

            # Creative Types
            "awareness": {
                "type": "creative_type",
                "properties": {
                    "goal": "brand_visibility",
                    "cta_strength": 0.3,
                },
            },
            "engagement": {
                "type": "creative_type",
                "properties": {
                    "goal": "interaction",
                    "cta_strength": 0.7,
                },
            },
            "conversion": {
                "type": "creative_type",
                "properties": {
                    "goal": "sales",
                    "cta_strength": 1.0,
                },
            },
        }

        # Edges (relationships)
        self.edges = defaultdict(list)
        self._build_relationships()

    def _build_relationships(self):
        """Populate ``self.edges`` with the built-in relationships."""
        # Tone -> Platform compatibility
        self.add_edge("fun", "Meta", "highly_compatible", weight=0.9)
        self.add_edge("fun", "LinkedIn", "moderately_compatible", weight=0.3)
        self.add_edge("fun", "Google", "poorly_compatible", weight=0.1)

        self.add_edge("professional", "LinkedIn", "highly_compatible", weight=0.95)
        self.add_edge("professional", "Google", "highly_compatible", weight=0.9)
        self.add_edge("professional", "Meta", "moderately_compatible", weight=0.5)

        self.add_edge("semi-fun", "Meta", "highly_compatible", weight=0.8)
        self.add_edge("semi-fun", "LinkedIn", "highly_compatible", weight=0.7)
        self.add_edge("semi-fun", "Google", "moderately_compatible", weight=0.5)

        # Tone -> Creative Type
        self.add_edge("fun", "awareness", "suitable_for", weight=0.9)
        self.add_edge("fun", "engagement", "suitable_for", weight=0.95)
        self.add_edge("professional", "conversion", "suitable_for", weight=0.9)
        self.add_edge("semi-fun", "engagement", "suitable_for", weight=0.8)

        # Platform -> Creative Type preferences
        self.add_edge("Meta", "engagement", "prefers", weight=0.9)
        self.add_edge("LinkedIn", "conversion", "prefers", weight=0.8)
        self.add_edge("Google", "conversion", "prefers", weight=0.95)

    def add_edge(self, from_node: str, to_node: str, relationship: str, weight: float = 1.0):
        """Append a directed edge ``from_node -> to_node`` to the graph.

        Args:
            from_node: Name of the source node.
            to_node: Name of the target node.
            relationship: Label describing the edge.
            weight: Strength of the relationship; the built-in edges use
                values between 0 and 1, where higher means stronger.
        """
        self.edges[from_node].append({
            "to": to_node,
            "relationship": relationship,
            "weight": weight,
        })

    def _known_nodes(self) -> Set[str]:
        """Return every node name declared in ``nodes`` or used by an edge."""
        known = set(self.nodes)
        for source, outgoing in self.edges.items():
            known.add(source)
            known.update(edge["to"] for edge in outgoing)
        return known

    def traverse_bfs(self, start_node: str, max_depth: int = 2) -> Dict[str, List[Tuple[str, str, float]]]:
        """Breadth-first traversal collecting the edges that lead to each node.

        Nodes are expanded up to and including depth ``max_depth`` (the start
        node is depth 0). The outgoing edges of a node at depth ``max_depth``
        are still recorded, but their targets are not expanded further.

        Args:
            start_node: Node to start the traversal from.
            max_depth: Deepest level whose nodes are expanded.

        Returns:
            A dict mapping each reached node to a list of
            ``(source_node, relationship, weight)`` tuples, one per edge
            found pointing at it.
        """
        visited = set()
        queue = deque([(start_node, 0)])
        paths = defaultdict(list)

        while queue:
            current_node, depth = queue.popleft()

            if current_node in visited or depth > max_depth:
                continue

            visited.add(current_node)

            for edge in self.edges.get(current_node, []):
                to_node = edge["to"]
                paths[to_node].append((current_node, edge["relationship"], edge["weight"]))

                if depth < max_depth:
                    queue.append((to_node, depth + 1))

        return dict(paths)

    def find_best_path(self, start: str, end: str) -> Optional[List[Tuple[str, str, float]]]:
        """Find the strongest path from ``start`` to ``end`` (Dijkstra-style).

        Each edge weight is converted to a distance of ``1 - weight`` so that
        stronger relationships are "shorter". The search runs over every node
        in ``self.nodes`` plus every node mentioned by an edge, so edges added
        through ``add_edge`` to undeclared nodes do not raise ``KeyError``.

        Args:
            start: Name of the node to start from.
            end: Name of the node to reach.

        Returns:
            The hops in order from ``start`` to ``end``, each as a
            ``(source_node, relationship, weight)`` tuple, where the target of
            a hop is the source of the next hop (or ``end`` for the last one).
            ``None`` when ``end`` is unreachable or when ``start == end``.
        """
        infinity = float("inf")
        universe = self._known_nodes()
        if start not in universe:
            return None

        distances = {node: infinity for node in universe}
        distances[start] = 0.0
        previous = {}
        unvisited = set(universe)

        while unvisited:
            current = min(unvisited, key=distances.__getitem__)

            # Everything still unvisited is unreachable from ``start``.
            if distances[current] == infinity:
                break

            unvisited.remove(current)

            for edge in self.edges.get(current, []):
                neighbor = edge["to"]
                # Convert strength to distance (lower is better). Clamp at 0
                # so a weight above 1 cannot create a negative edge, which
                # Dijkstra's algorithm does not support.
                step = max(0.0, 1 - edge["weight"])
                candidate = distances[current] + step

                if candidate < distances[neighbor]:
                    distances[neighbor] = candidate
                    previous[neighbor] = (current, edge["relationship"], edge["weight"])

        # Reconstruct the path by walking the predecessor links backwards.
        if end not in previous:
            return None

        path = []
        current = end
        while current != start:
            if current not in previous:
                return None
            prev_node, rel, weight = previous[current]
            path.append((prev_node, rel, weight))
            current = prev_node

        path.reverse()
        return path

    def get_recommendations(self, tone: str, platform: str) -> Dict[str, object]:
        """Build content recommendations for a tone/platform pair.

        Args:
            tone: Name of a tone node.
            platform: Name of a platform node.

        Returns:
            A dict with the keys ``"compatibility_score"`` (weight of the
            direct tone -> platform edge, or 0 when there is none),
            ``"suggested_elements"``, ``"warnings"`` and ``"creative_types"``.
        """
        recommendations = {
            "compatibility_score": 0,
            "suggested_elements": [],
            "warnings": [],
            "creative_types": [],
        }

        # Check direct compatibility
        for edge in self.edges.get(tone, []):
            if edge["to"] == platform:
                recommendations["compatibility_score"] = edge["weight"]
                break

        # Creative types the tone is strongly suitable for
        for node, paths in self.traverse_bfs(tone, max_depth=1).items():
            if self.nodes.get(node, {}).get("type") != "creative_type":
                continue
            for _, rel, weight in paths:
                if rel == "suitable_for" and weight > 0.7:
                    recommendations["creative_types"].append(node)

        # Platform-specific suggestions
        platform_props = self.nodes.get(platform, {}).get("properties", {})
        tone_props = self.nodes.get(tone, {}).get("properties", {})

        if platform_props.get("emoji_friendly") and tone_props.get("creativity", 0) > 0.7:
            recommendations["suggested_elements"].append("Use emojis to enhance engagement")
        elif not platform_props.get("emoji_friendly") and tone == "fun":
            recommendations["warnings"].append("Platform doesn't support emojis well - adjust tone")

        if platform_props.get("char_limit", float('inf')) < 100:
            recommendations["suggested_elements"].append("Keep message extremely concise")

        return recommendations

    def explain_relationship(self, node1: str, node2: str) -> str:
        """Describe how ``node1`` relates to ``node2`` in plain text.

        A direct edge is reported on its own; otherwise the best multi-hop
        path is rendered as ``source relationship target`` segments joined by
        arrows.

        Args:
            node1: Name of the node the relationship starts from.
            node2: Name of the node the relationship points to.

        Returns:
            A human-readable explanation, or a "no relationship" message when
            ``node2`` cannot be reached from ``node1``.
        """
        # Check direct connection first
        for edge in self.edges.get(node1, []):
            if edge["to"] == node2:
                return f"{node1} is {edge['relationship']} with {node2} (strength: {edge['weight']:.2f})"

        # If no direct connection, find path
        path = self.find_best_path(node1, node2)

        if not path:
            return f"No direct relationship found between {node1} and {node2}"

        # ``path`` is ordered from node1 to node2 and each hop stores its
        # SOURCE node, so a hop's target is the next hop's source (or node2
        # for the final hop).
        explanation = []
        for index, (source, relationship, weight) in enumerate(path):
            target = path[index + 1][0] if index + 1 < len(path) else node2
            explanation.append(f"{source} {relationship} {target} (strength: {weight:.2f})")

        return " → ".join(explanation)
enhanced_retriever.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Dict, Tuple
2
+ import numpy as np
3
+ from collections import defaultdict
4
+ import re
5
+
6
class EnhancedRetriever:
    """Guideline retriever that ranks entries with a simple feature embedding.

    Attributes:
        guideline_path: Path of the guideline text file.
        guidelines: Maps a lowercase category name to its guideline strings.
        embeddings_cache: Maps guideline text to its embedding vector so each
            guideline is only embedded once.
    """

    def __init__(self, guideline_path: str = "tone_guidelines.txt"):
        """Load the guidelines found at ``guideline_path``.

        Raises:
            OSError: If the guideline file cannot be opened.
        """
        self.guideline_path = guideline_path
        self.guidelines = self._load_guidelines()
        self.embeddings_cache = {}

    def _load_guidelines(self) -> Dict[str, List[str]]:
        """Parse the guideline file into ``{category: [guideline, ...]}``.

        A non-bullet line containing ``:`` starts a new category (colons
        removed, lowercased). Every other non-empty line is stored as a
        guideline of the current category with surrounding ``-`` and spaces
        stripped. Lines starting with ``-`` are always treated as guidelines,
        so a bullet that happens to contain a colon is not mistaken for a
        category header.
        """
        guidelines = defaultdict(list)
        current_key = None

        with open(self.guideline_path, "r", encoding="utf-8") as f:
            for raw_line in f:
                line = raw_line.strip()
                if not line:
                    continue
                if ":" in line and not line.startswith("-"):
                    current_key = line.replace(":", "").strip().lower()
                elif current_key:
                    guidelines[current_key].append(line.strip("- ").strip())

        return dict(guidelines)

    def _simple_embedding(self, text: str) -> np.ndarray:
        """Turn ``text`` into a small hand-crafted feature vector.

        The features are, in order: word count, emoji presence, exclamation
        mark presence, number of formal keywords, number of casual keywords,
        call-to-action keyword presence and hashtag mention.

        Returns:
            A ``float32`` array with one entry per feature.
        """
        # Normalize text
        text = text.lower()

        # Extract key features
        features = {
            'length': len(text.split()),
            'has_emoji': int(bool(re.search(r'[😀-🙏]', text))),
            'has_exclamation': int('!' in text),
            'formal_words': sum(1 for word in ['professional', 'value', 'benefits', 'business'] if word in text),
            'casual_words': sum(1 for word in ['fun', 'playful', 'emoji', 'snappy'] if word in text),
            'cta_presence': int(any(word in text for word in ['cta', 'button', 'click'])),
            'hashtag_mention': int('#' in text or 'hashtag' in text),
        }

        # Convert to vector (dicts preserve insertion order)
        return np.array(list(features.values()), dtype=np.float32)

    def _cached_embedding(self, text: str) -> np.ndarray:
        """Return the embedding of ``text``, computing it at most once."""
        cache = self.embeddings_cache
        if text not in cache:
            cache[text] = self._simple_embedding(text)
        return cache[text]

    def _cosine_similarity(self, vec1: np.ndarray, vec2: np.ndarray) -> float:
        """Return the cosine similarity of two vectors as a Python ``float``.

        Returns ``0.0`` when either vector has zero length.
        """
        norm1 = np.linalg.norm(vec1)
        norm2 = np.linalg.norm(vec2)

        if norm1 == 0 or norm2 == 0:
            return 0.0

        # Convert the NumPy scalar so callers get a plain, JSON-serialisable
        # float as the annotation promises.
        return float(np.dot(vec1, vec2) / (norm1 * norm2))

    def semantic_search(self, query: str, top_k: int = 5) -> List[Tuple[str, str, float]]:
        """Rank every guideline by similarity to ``query``.

        Args:
            query: Free-text query.
            top_k: Maximum number of results to return.

        Returns:
            Up to ``top_k`` ``(category, guideline, similarity)`` tuples,
            most similar first.
        """
        query_embedding = self._simple_embedding(query)
        results = []

        for category, items in self.guidelines.items():
            for item in items:
                # Guideline texts are a bounded set, so caching them is safe;
                # queries are not cached to keep the cache from growing.
                item_embedding = self._cached_embedding(item)
                similarity = self._cosine_similarity(query_embedding, item_embedding)
                results.append((category, item, similarity))

        # Sort by similarity score
        results.sort(key=lambda x: x[2], reverse=True)
        return results[:top_k]

    def retrieve_with_relevance(self, tone: str, platforms: List[str]) -> Dict[str, object]:
        """Retrieve guidelines for a tone and a list of platforms.

        Args:
            tone: Tone name; matched case-insensitively against categories.
            platforms: Platform names; matched case-insensitively.

        Returns:
            A dict with ``"direct_matches"`` (guidelines keyed by the tone /
            platform name as given), ``"relevance_scores"`` (1.0 for every
            direct match) and ``"semantic_matches"`` (ranked guidelines from
            categories that were not matched directly).
        """
        context_query = f"{tone} tone for {' '.join(platforms)} platforms"
        semantic_results = self.semantic_search(context_query)

        # Structure the response with relevance scores
        response = {
            "direct_matches": {},
            "semantic_matches": [],
            "relevance_scores": {},
        }
        # Categories are stored lowercase while ``direct_matches`` keeps the
        # caller's casing, so track the lowercase names separately.
        matched_categories = set()

        # Direct matches (existing logic)
        for name in [tone, *platforms]:
            category = name.lower()
            if category in self.guidelines:
                response["direct_matches"][name] = self.guidelines[category]
                response["relevance_scores"][name] = 1.0
                matched_categories.add(category)

        # Add semantic matches from categories not already returned directly
        for category, item, score in semantic_results:
            if category not in matched_categories:
                response["semantic_matches"].append({
                    "category": category,
                    "guideline": item,
                    "relevance": score,
                })

        return response

    def format_guidance_with_scores(self, retrieval_result: Dict) -> str:
        """Render a ``retrieve_with_relevance`` result as readable text.

        Direct matches are listed in full; at most the first three semantic
        matches are appended.
        """
        output = []

        # Direct matches
        for key, guidelines in retrieval_result["direct_matches"].items():
            score = retrieval_result["relevance_scores"].get(key, 0)
            output.append(f"\n{key} Guidelines (Relevance: {score:.2f}):")
            for guideline in guidelines:
                output.append(f" - {guideline}")

        # Semantic matches
        if retrieval_result["semantic_matches"]:
            output.append("\nAdditional Relevant Guidelines:")
            for match in retrieval_result["semantic_matches"][:3]:  # Top 3
                output.append(f" - [{match['category']}] {match['guideline']} (Score: {match['relevance']:.2f})")

        return "\n".join(output)