Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -1,156 +1,149 @@
|
|
1 |
import gradio as gr
|
2 |
import networkx as nx
|
3 |
import matplotlib.pyplot as plt
|
|
|
4 |
import spacy
|
5 |
-
import
|
6 |
-
import numpy as np
|
7 |
from pathlib import Path
|
8 |
|
9 |
-
# Load
|
10 |
nlp = spacy.load("en_core_web_sm")
|
|
|
11 |
|
12 |
-
#
|
13 |
CATEGORIES = {
|
14 |
-
"Main
|
15 |
-
"
|
16 |
-
"
|
17 |
-
"
|
18 |
-
"
|
19 |
}
|
20 |
|
21 |
-
def
|
22 |
-
"""Load
|
23 |
try:
|
24 |
with open("Unit5_OCR.txt", "r", encoding="utf-8") as f:
|
25 |
-
|
26 |
-
return content
|
27 |
except FileNotFoundError:
|
28 |
-
return
|
29 |
|
30 |
-
def
|
31 |
-
"""
|
32 |
-
|
33 |
-
|
34 |
|
35 |
-
#
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
entities[ent.text]["count"] += 1
|
46 |
-
|
47 |
-
return entities
|
48 |
|
49 |
-
def
|
50 |
-
"""
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
if ent.text.lower() != term:
|
67 |
-
if ent.text not in related:
|
68 |
-
related[ent.text] = {
|
69 |
-
"type": ent.label_,
|
70 |
-
"count": 1,
|
71 |
-
"relevance": 1.0
|
72 |
-
}
|
73 |
-
else:
|
74 |
-
related[ent.text]["count"] += 1
|
75 |
-
related[ent.text]["relevance"] += 0.5
|
76 |
-
|
77 |
-
index = text.find(term, index + 1)
|
78 |
-
|
79 |
-
return related
|
80 |
|
81 |
def generate_context_map(term):
|
82 |
"""Generate a network visualization for the given term."""
|
83 |
-
if not term.strip():
|
84 |
return None
|
85 |
-
|
86 |
-
# Load
|
87 |
-
content =
|
88 |
-
if
|
89 |
return None
|
90 |
|
91 |
-
#
|
92 |
-
|
|
|
|
|
93 |
|
94 |
-
#
|
95 |
-
|
96 |
|
97 |
-
#
|
98 |
-
G
|
99 |
|
100 |
-
# Add
|
101 |
-
|
102 |
-
|
103 |
-
reverse=True)[:10]
|
104 |
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
110 |
|
111 |
# Create visualization
|
112 |
plt.figure(figsize=(12, 12))
|
113 |
plt.clf()
|
114 |
|
115 |
-
# Set
|
116 |
-
|
|
|
|
|
|
|
|
|
117 |
|
118 |
-
# Draw nodes
|
119 |
for category, color in CATEGORIES.items():
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
|
|
|
|
125 |
|
126 |
# Draw edges
|
127 |
-
nx.draw_networkx_edges(G, pos, edge_color='white',
|
128 |
-
width=1, alpha=0.5)
|
129 |
|
130 |
# Add labels
|
131 |
-
|
132 |
-
nx.draw_networkx_labels(G, pos, labels, font_size=8,
|
133 |
-
font_color='white')
|
134 |
-
|
135 |
-
# Set dark background
|
136 |
-
plt.gca().set_facecolor('#1a1a1a')
|
137 |
-
plt.gcf().set_facecolor('#1a1a1a')
|
138 |
|
139 |
# Add title
|
140 |
plt.title(f"Historical Context Map for '{term}'",
|
141 |
-
color='white',
|
|
|
142 |
|
143 |
return plt.gcf()
|
144 |
|
145 |
# Create Gradio interface
|
146 |
iface = gr.Interface(
|
147 |
fn=generate_context_map,
|
148 |
-
inputs=gr.Textbox(
|
149 |
-
|
|
|
|
|
150 |
outputs=gr.Plot(),
|
151 |
title="Historical Context Mapper",
|
152 |
-
description="
|
153 |
-
theme="darkhuggingface",
|
154 |
examples=[
|
155 |
["Civil War"],
|
156 |
["Abraham Lincoln"],
|
|
|
1 |
import gradio as gr
|
2 |
import networkx as nx
|
3 |
import matplotlib.pyplot as plt
|
4 |
+
from transformers import pipeline
|
5 |
import spacy
|
6 |
+
import torch
|
|
|
7 |
from pathlib import Path
|
8 |
|
9 |
+
# Models are created once at import time so that every request reuses them.
nlp = spacy.load("en_core_web_sm")
summarizer = pipeline(
    "summarization",
    model="facebook/bart-large-cnn",
)

# Hex colour used to draw the nodes of each category in the context map.
CATEGORIES = {
    "Main Theme": "#004d99",
    "Event": "#006400",
    "Person": "#8b4513",
    "Law": "#4b0082",
    "Concept": "#800000",
}
|
21 |
|
22 |
+
def load_content(path="Unit5_OCR.txt"):
    """Load the Unit 5 content.

    Args:
        path: Location of the UTF-8 text file to read. Defaults to
            ``"Unit5_OCR.txt"``, resolved against the current working
            directory.

    Returns:
        The full file contents as a string, or ``None`` when the file does
        not exist.
    """
    try:
        with open(path, "r", encoding="utf-8") as f:
            return f.read()
    except FileNotFoundError:
        # A missing corpus is reported as "no content" so callers can
        # degrade gracefully instead of crashing the request.
        return None
|
29 |
|
30 |
+
def find_context(term: str, text: str, window_size: int = 500) -> str:
    """Find the relevant context around a term.

    The first case-insensitive occurrence of ``term`` in ``text`` is located
    and returned together with up to ``window_size`` characters on either
    side of it.

    Args:
        term: The phrase to look for (matched literally, ignoring case).
        text: The text to search.
        window_size: Number of characters of context to keep before and
            after the match.

    Returns:
        The slice of ``text`` surrounding the first match, or ``""`` when
        ``term`` or ``text`` is empty or the term does not occur.
    """
    import re

    if not term or not text:
        return ""

    # Search the original text directly: offsets found in ``text.lower()``
    # are not valid for ``text`` when lowercasing changes a character's
    # length (e.g. "İ"), which would shift the returned window.
    match = re.search(re.escape(term), text, flags=re.IGNORECASE)
    if match is None:
        return ""

    start = max(0, match.start() - window_size)
    end = min(len(text), match.end() + window_size)
    return text[start:end]
|
|
|
|
|
|
|
45 |
|
46 |
+
# spaCy entity labels that map onto one of the display categories.
_LABEL_CATEGORIES = {
    "PERSON": "Person",
    "EVENT": "Event",
    "DATE": "Event",
    "LAW": "Law",
    "ORG": "Law",
}

# Terms that are always treated as overarching themes (compared lowercased).
_MAIN_THEMES = frozenset(
    {"manifest destiny", "reconstruction", "civil war", "slavery"}
)


def categorize_term(term, doc):
    """Categorize a term based on NER and custom rules.

    Args:
        term: The term to classify. Surrounding whitespace and letter case
            are ignored.
        doc: An object exposing ``ents``, an iterable of entities that each
            have ``text`` and ``label_`` attributes (e.g. a spaCy ``Doc``).

    Returns:
        ``"Person"``, ``"Event"`` or ``"Law"`` when the first entity that
        contains the term and carries a mapped label says so; otherwise
        ``"Main Theme"`` for the known theme terms; otherwise ``"Concept"``.
    """
    needle = term.strip().lower() if term else ""
    if not needle:
        # An empty string is a substring of every entity, so without this
        # guard a blank term would inherit the first entity's category.
        return "Concept"

    for ent in doc.ents:
        if needle in ent.text.lower():
            category = _LABEL_CATEGORIES.get(ent.label_)
            # Entities with an unmapped label (e.g. GPE) do not decide the
            # category; keep scanning the remaining entities.
            if category is not None:
                return category

    # Custom categorization for common terms.
    if needle in _MAIN_THEMES:
        return "Main Theme"

    return "Concept"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
63 |
|
64 |
def generate_context_map(term):
    """Generate a network visualization for the given term.

    The term is looked up in the loaded content, the surrounding passage is
    run through the NLP pipeline, and the term is drawn as a hub connected
    to the first distinct entities found in that passage (at most eight).
    Nodes are coloured by category using ``CATEGORIES``.

    Args:
        term: The term to map. Surrounding whitespace is ignored.

    Returns:
        A matplotlib ``Figure`` containing the map, or ``None`` when the
        term is blank, no content could be loaded, or the term does not
        occur in the content.
    """
    max_related = 8

    # Normalise once so the lookup, node label and title ignore stray spaces.
    term = term.strip() if term else ""
    if not term:
        return None

    # Load content
    content = load_content()
    if not content:
        return None

    # Get context
    context = find_context(term, content)
    if not context:
        return None

    # Process context
    doc = nlp(context)

    # Create graph with the main term as the hub
    graph = nx.Graph()
    graph.add_node(term, category=categorize_term(term, doc))

    # Collect distinct related entities in document order. Repeated
    # mentions are skipped so they do not use up the node limit, and only
    # the entities that will actually be drawn get categorised.
    term_key = term.lower()
    seen = set()
    related = []
    for ent in doc.ents:
        text = ent.text
        if text.lower() == term_key or text in seen:
            continue
        seen.add(text)
        related.append(text)
        if len(related) == max_related:
            break

    for text in related:
        graph.add_node(text, category=categorize_term(text, doc))
        graph.add_edge(term, text)

    # Build the figure without pyplot: pyplot keeps every figure it creates
    # alive until it is explicitly closed, which leaks memory in a server
    # that renders one figure per request.
    from matplotlib.figure import Figure

    fig = Figure(figsize=(12, 12))
    ax = fig.subplots()

    # Set dark background
    ax.set_facecolor('#1a1a1a')
    fig.set_facecolor('#1a1a1a')

    # Create layout
    pos = nx.spring_layout(graph, k=1)

    # Draw nodes for each category
    for category, color in CATEGORIES.items():
        members = [node for node, attr in graph.nodes(data=True)
                   if attr.get('category') == category]
        if members:
            nx.draw_networkx_nodes(graph, pos,
                                   nodelist=members,
                                   node_color=color,
                                   node_size=2000,
                                   ax=ax)

    # Draw edges
    nx.draw_networkx_edges(graph, pos, edge_color='white', width=1, ax=ax)

    # Add labels
    nx.draw_networkx_labels(graph, pos, font_size=8, font_color='white',
                            ax=ax)

    # Add title
    ax.set_title(f"Historical Context Map for '{term}'",
                 color='white',
                 pad=20)

    return fig
|
136 |
|
137 |
# Create Gradio interface
|
138 |
iface = gr.Interface(
|
139 |
fn=generate_context_map,
|
140 |
+
inputs=gr.Textbox(
|
141 |
+
label="Enter a historical term from Unit 5",
|
142 |
+
placeholder="e.g., Civil War, Abraham Lincoln, Reconstruction"
|
143 |
+
),
|
144 |
outputs=gr.Plot(),
|
145 |
title="Historical Context Mapper",
|
146 |
+
description="Enter a term from Unit 5 (1844-1877) to see its historical context and connections.",
|
|
|
147 |
examples=[
|
148 |
["Civil War"],
|
149 |
["Abraham Lincoln"],
|