ambrosfitz commited on
Commit
1ef8711
·
verified ·
1 Parent(s): 43524e1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +97 -104
app.py CHANGED
@@ -1,156 +1,149 @@
1
  import gradio as gr
2
  import networkx as nx
3
  import matplotlib.pyplot as plt
 
4
  import spacy
5
- import pandas as pd
6
- import numpy as np
7
  from pathlib import Path
8
 
9
- # Load SpaCy model
10
  nlp = spacy.load("en_core_web_sm")
 
11
 
12
- # Categories and their colors
13
  CATEGORIES = {
14
- "Main Themes": "#004d99",
15
- "Events": "#006400",
16
- "People": "#8b4513",
17
- "Laws/Policies": "#4b0082",
18
- "Concepts": "#800000"
19
  }
20
 
21
- def load_historical_data():
22
- """Load and process the Unit 5 text data."""
23
  try:
24
  with open("Unit5_OCR.txt", "r", encoding="utf-8") as f:
25
- content = f.read()
26
- return content
27
  except FileNotFoundError:
28
- return "Historical data file not found."
29
 
30
- def extract_entities(text):
31
- """Extract named entities and important terms from text."""
32
- doc = nlp(text)
33
- entities = {}
34
 
35
- # Extract named entities
36
- for ent in doc.ents:
37
- if ent.label_ in ["PERSON", "EVENT", "DATE", "LAW", "ORG"]:
38
- if ent.text not in entities:
39
- entities[ent.text] = {
40
- "type": ent.label_,
41
- "count": 1,
42
- "context": []
43
- }
44
- else:
45
- entities[ent.text]["count"] += 1
46
-
47
- return entities
48
 
49
- def find_related_terms(term, text, window_size=100):
50
- """Find terms that appear near the search term."""
51
- term = term.lower()
52
- text = text.lower()
53
- related = {}
54
-
55
- # Find all occurrences of the term
56
- index = text.find(term)
57
- while index != -1:
58
- # Get surrounding context
59
- start = max(0, index - window_size)
60
- end = min(len(text), index + len(term) + window_size)
61
- context = text[start:end]
62
-
63
- # Process context to find other entities
64
- doc = nlp(context)
65
- for ent in doc.ents:
66
- if ent.text.lower() != term:
67
- if ent.text not in related:
68
- related[ent.text] = {
69
- "type": ent.label_,
70
- "count": 1,
71
- "relevance": 1.0
72
- }
73
- else:
74
- related[ent.text]["count"] += 1
75
- related[ent.text]["relevance"] += 0.5
76
-
77
- index = text.find(term, index + 1)
78
-
79
- return related
80
 
81
  def generate_context_map(term):
82
  """Generate a network visualization for the given term."""
83
- if not term.strip():
84
  return None
85
-
86
- # Load historical data
87
- content = load_historical_data()
88
- if content == "Historical data file not found.":
89
  return None
90
 
91
- # Create network graph
92
- G = nx.Graph()
 
 
93
 
94
- # Find related terms
95
- related_items = find_related_terms(term, content)
96
 
97
- # Add central node
98
- G.add_node(term, category="Main Themes")
99
 
100
- # Add related nodes (limit to top 10 by relevance)
101
- sorted_items = sorted(related_items.items(),
102
- key=lambda x: x[1]["relevance"],
103
- reverse=True)[:10]
104
 
105
- for item_name, item_data in sorted_items:
106
- G.add_node(item_name, category=item_data["type"])
107
- G.add_edge(term, item_name,
108
- weight=item_data["relevance"],
109
- length=2.0/item_data["relevance"])
 
 
 
 
 
 
 
 
110
 
111
  # Create visualization
112
  plt.figure(figsize=(12, 12))
113
  plt.clf()
114
 
115
- # Set up the layout
116
- pos = nx.spring_layout(G, k=1, iterations=50)
 
 
 
 
117
 
118
- # Draw nodes
119
  for category, color in CATEGORIES.items():
120
- nodes = [node for node, attr in G.nodes(data=True)
121
- if attr.get("category", "") == category]
122
- nx.draw_networkx_nodes(G, pos, nodelist=nodes,
123
- node_color=color,
124
- node_size=2000)
 
 
125
 
126
  # Draw edges
127
- nx.draw_networkx_edges(G, pos, edge_color='white',
128
- width=1, alpha=0.5)
129
 
130
  # Add labels
131
- labels = {node: node for node in G.nodes()}
132
- nx.draw_networkx_labels(G, pos, labels, font_size=8,
133
- font_color='white')
134
-
135
- # Set dark background
136
- plt.gca().set_facecolor('#1a1a1a')
137
- plt.gcf().set_facecolor('#1a1a1a')
138
 
139
  # Add title
140
  plt.title(f"Historical Context Map for '{term}'",
141
- color='white', pad=20)
 
142
 
143
  return plt.gcf()
144
 
145
  # Create Gradio interface
146
  iface = gr.Interface(
147
  fn=generate_context_map,
148
- inputs=gr.Textbox(label="Enter a historical term from Unit 5",
149
- placeholder="e.g., Civil War, Abraham Lincoln, Reconstruction"),
 
 
150
  outputs=gr.Plot(),
151
  title="Historical Context Mapper",
152
- description="This tool generates a network visualization showing the historical context and connections for terms from Unit 5 (1844-1877).",
153
- theme="darkhuggingface",
154
  examples=[
155
  ["Civil War"],
156
  ["Abraham Lincoln"],
 
1
  import gradio as gr
2
  import networkx as nx
3
  import matplotlib.pyplot as plt
4
+ from transformers import pipeline
5
  import spacy
6
+ import torch
 
7
  from pathlib import Path
8
 
9
+ # Load models
10
  nlp = spacy.load("en_core_web_sm")
11
+ summarizer = pipeline("summarization", model="facebook/bart-large-cnn")
12
 
13
+ # Define categories and their colors
14
  CATEGORIES = {
15
+ "Main Theme": "#004d99",
16
+ "Event": "#006400",
17
+ "Person": "#8b4513",
18
+ "Law": "#4b0082",
19
+ "Concept": "#800000"
20
  }
21
 
22
+ def load_content():
23
+ """Load the Unit 5 content."""
24
  try:
25
  with open("Unit5_OCR.txt", "r", encoding="utf-8") as f:
26
+ return f.read()
 
27
  except FileNotFoundError:
28
+ return None
29
 
30
+ def find_context(term, text, window_size=500):
31
+ """Find the relevant context around a term."""
32
+ term_lower = term.lower()
33
+ text_lower = text.lower()
34
 
35
+ # Find the term in text
36
+ index = text_lower.find(term_lower)
37
+ if index == -1:
38
+ return ""
39
+
40
+ # Get surrounding context
41
+ start = max(0, index - window_size)
42
+ end = min(len(text), index + len(term) + window_size)
43
+
44
+ return text[start:end]
 
 
 
45
 
46
+ def categorize_term(term, doc):
47
+ """Categorize a term based on NER and custom rules."""
48
+ for ent in doc.ents:
49
+ if term.lower() in ent.text.lower():
50
+ if ent.label_ == "PERSON":
51
+ return "Person"
52
+ elif ent.label_ == "EVENT" or ent.label_ == "DATE":
53
+ return "Event"
54
+ elif ent.label_ == "LAW" or ent.label_ == "ORG":
55
+ return "Law"
56
+
57
+ # Custom categorization for common terms
58
+ themes = ["manifest destiny", "reconstruction", "civil war", "slavery"]
59
+ if term.lower() in themes:
60
+ return "Main Theme"
61
+
62
+ return "Concept"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
 
64
  def generate_context_map(term):
65
  """Generate a network visualization for the given term."""
66
+ if not term or not term.strip():
67
  return None
68
+
69
+ # Load content
70
+ content = load_content()
71
+ if not content:
72
  return None
73
 
74
+ # Get context
75
+ context = find_context(term, content)
76
+ if not context:
77
+ return None
78
 
79
+ # Process context
80
+ doc = nlp(context)
81
 
82
+ # Create graph
83
+ G = nx.Graph()
84
 
85
+ # Add main term
86
+ term_category = categorize_term(term, doc)
87
+ G.add_node(term, category=term_category)
 
88
 
89
+ # Find related entities
90
+ related_entities = []
91
+ for ent in doc.ents:
92
+ if ent.text.lower() != term.lower():
93
+ related_entities.append({
94
+ 'text': ent.text,
95
+ 'category': categorize_term(ent.text, doc)
96
+ })
97
+
98
+ # Add top related entities (limit to 8)
99
+ for entity in related_entities[:8]:
100
+ G.add_node(entity['text'], category=entity['category'])
101
+ G.add_edge(term, entity['text'])
102
 
103
  # Create visualization
104
  plt.figure(figsize=(12, 12))
105
  plt.clf()
106
 
107
+ # Set dark background
108
+ plt.gca().set_facecolor('#1a1a1a')
109
+ plt.gcf().set_facecolor('#1a1a1a')
110
+
111
+ # Create layout
112
+ pos = nx.spring_layout(G, k=1)
113
 
114
+ # Draw nodes for each category
115
  for category, color in CATEGORIES.items():
116
+ node_list = [node for node, attr in G.nodes(data=True)
117
+ if attr.get('category') == category]
118
+ if node_list:
119
+ nx.draw_networkx_nodes(G, pos,
120
+ nodelist=node_list,
121
+ node_color=color,
122
+ node_size=2000)
123
 
124
  # Draw edges
125
+ nx.draw_networkx_edges(G, pos, edge_color='white', width=1)
 
126
 
127
  # Add labels
128
+ nx.draw_networkx_labels(G, pos, font_size=8, font_color='white')
 
 
 
 
 
 
129
 
130
  # Add title
131
  plt.title(f"Historical Context Map for '{term}'",
132
+ color='white',
133
+ pad=20)
134
 
135
  return plt.gcf()
136
 
137
  # Create Gradio interface
138
  iface = gr.Interface(
139
  fn=generate_context_map,
140
+ inputs=gr.Textbox(
141
+ label="Enter a historical term from Unit 5",
142
+ placeholder="e.g., Civil War, Abraham Lincoln, Reconstruction"
143
+ ),
144
  outputs=gr.Plot(),
145
  title="Historical Context Mapper",
146
+ description="Enter a term from Unit 5 (1844-1877) to see its historical context and connections.",
 
147
  examples=[
148
  ["Civil War"],
149
  ["Abraham Lincoln"],