Krish Patel committed
Commit 990f77e · 1 Parent(s): 81f219e

Model upload

.gitignore ADDED
@@ -0,0 +1,2 @@
+ __pycache__/
+ node_modules/
app.py ADDED
@@ -0,0 +1,113 @@
+ # # from fastapi import FastAPI
+ # # from pydantic import BaseModel
+ # # from final import predict_news, get_gemini_analysis
+
+ # # app = FastAPI()
+
+ # # class NewsInput(BaseModel):
+ # #     text: str
+
+ # # @app.post("/analyze")
+ # # async def analyze_news(news: NewsInput):
+ # #     # Get ML and Knowledge Graph prediction
+ # #     prediction = predict_news(news.text)
+
+ # #     # Get Gemini analysis
+ # #     gemini_analysis = get_gemini_analysis(news.text)
+
+ # #     return {
+ # #         "prediction": prediction,
+ # #         "detailed_analysis": gemini_analysis
+ # #     }
+
+ # # @app.get("/health")
+ # # async def health_check():
+ # #     return {"status": "healthy"}
+
+ # from fastapi import FastAPI
+ # from fastapi.middleware.cors import CORSMiddleware
+ # from pydantic import BaseModel
+ # from final import predict_news, get_gemini_analysis
+
+ # app = FastAPI()
+
+ # # Add CORS middleware
+ # app.add_middleware(
+ #     CORSMiddleware,
+ #     allow_origins=["http://localhost:5173"],  # Your React app's URL
+ #     allow_credentials=True,
+ #     allow_methods=["*"],
+ #     allow_headers=["*"],
+ # )
+
+ # # Rest of your code remains the same
+ # class NewsInput(BaseModel):
+ #     text: str
+
+ # @app.post("/analyze")
+ # async def analyze_news(news: NewsInput):
+ #     prediction = predict_news(news.text)
+ #     gemini_analysis = get_gemini_analysis(news.text)
+
+ #     return {
+ #         "prediction": prediction,
+ #         "detailed_analysis": gemini_analysis
+ #     }
+
+ import streamlit as st
+ from final import predict_news, get_gemini_analysis
+
+ def main():
+     st.title("News Fact Checker")
+     st.write("Enter news text to analyze its authenticity")
+
+     # Text input area
+     news_text = st.text_area("Enter news text here:", height=200)
+
+     if st.button("Analyze"):
+         if news_text:
+             with st.spinner("Analyzing..."):
+                 # Get predictions and analysis
+                 prediction = predict_news(news_text)
+                 gemini_analysis = get_gemini_analysis(news_text)
+
+                 # Display results
+                 st.header("Analysis Results")
+
+                 # Main prediction with color coding; predict_news can also
+                 # return "UNCERTAIN" when the model and knowledge graph disagree
+                 color_map = {"REAL": "green", "FAKE": "red", "UNCERTAIN": "orange"}
+                 prediction_color = color_map.get(prediction, "gray")
+                 st.markdown(f"### Prediction: <span style='color:{prediction_color}'>{prediction}</span>", unsafe_allow_html=True)
+
+                 # Detailed Gemini Analysis
+                 st.subheader("Detailed Analysis")
+
+                 # Guard against the fallback payload returned when the Gemini
+                 # response cannot be parsed as JSON (it only carries "gemini_analysis")
+                 if 'text_classification' not in gemini_analysis:
+                     st.error("Detailed analysis unavailable: the Gemini response could not be parsed.")
+                     st.stop()
+
+                 # Display structured analysis
+                 col1, col2 = st.columns(2)
+
+                 with col1:
+                     st.markdown("#### Content Classification")
+                     st.write(f"Category: {gemini_analysis['text_classification']['category']}")
+                     st.write(f"Writing Style: {gemini_analysis['text_classification']['writing_style']}")
+                     st.write(f"Content Type: {gemini_analysis['text_classification']['content_type']}")
+
+                 with col2:
+                     st.markdown("#### Sentiment Analysis")
+                     st.write(f"Primary Emotion: {gemini_analysis['sentiment_analysis']['primary_emotion']}")
+                     st.write(f"Emotional Intensity: {gemini_analysis['sentiment_analysis']['emotional_intensity']}/10")
+                     st.write(f"Sensationalism Level: {gemini_analysis['sentiment_analysis']['sensationalism_level']}")
+
+                 # Fact checking section
+                 st.markdown("#### Fact Checking")
+                 st.write(f"Evidence Present: {gemini_analysis['fact_checking']['evidence_present']}")
+                 st.write(f"Fact Check Score: {gemini_analysis['fact_checking']['fact_check_score']}/100")
+
+                 # Verifiable claims
+                 st.markdown("#### Verifiable Claims")
+                 for claim in gemini_analysis['fact_checking']['verifiable_claims']:
+                     st.write(f"- {claim}")
+
+         else:
+             st.warning("Please enter some text to analyze")
+
+ if __name__ == "__main__":
+     main()
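With both FastAPI drafts left commented out, the committed entrypoint is the Streamlit UI above. Assuming streamlit and the dependencies imported by final.py are installed, and the checkpoint and knowledge-graph files it expects are in place, the app can be launched locally with: streamlit run app.py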
final.py ADDED
@@ -0,0 +1,270 @@
+ import os
+ import torch
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
+ import networkx as nx
+ import spacy
+ import pickle
+ import google.generativeai as genai
+ import json
+
+ # Load spaCy for NER
+ nlp = spacy.load("en_core_web_sm")
+
+ # Load the trained ML model
+ model_path = "./results/checkpoint-5030"  # Replace with the actual path to your model
+ tokenizer = AutoTokenizer.from_pretrained('microsoft/deberta-v3-small')
+ model = AutoModelForSequenceClassification.from_pretrained(model_path)
+ model.eval()
+
+ #########################
+ def setup_gemini():
+     # Read the key from the environment rather than hard-coding it;
+     # export GOOGLE_API_KEY before running
+     genai.configure(api_key=os.environ["GOOGLE_API_KEY"])
+     model = genai.GenerativeModel('gemini-pro')
+     return model
+ #########################
+
+ # Load the knowledge graph
+ graph_path = "./models/knowledge_graph.pkl"  # Replace with the actual path to your knowledge graph
+ with open(graph_path, 'rb') as f:
+     graph_data = pickle.load(f)
+
+ knowledge_graph = nx.DiGraph()
+ knowledge_graph.add_nodes_from(graph_data['nodes'].items())
+ for u, edges in graph_data['edges'].items():
+     for v, data in edges.items():
+         knowledge_graph.add_edge(u, v, **data)
+
+ def predict_with_model(text):
+     """Predict whether the news is real or fake using the ML model."""
+     inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
+     with torch.no_grad():
+         outputs = model(**inputs)
+     probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
+     predicted_label = torch.argmax(probabilities, dim=-1).item()
+     return "FAKE" if predicted_label == 1 else "REAL"
+
+ def update_knowledge_graph(text, is_real):
+     """Update the knowledge graph with the new article."""
+     entities = extract_entities(text)
+     for entity, entity_type in entities:
+         if not knowledge_graph.has_node(entity):
+             knowledge_graph.add_node(
+                 entity,
+                 type=entity_type,
+                 real_count=1 if is_real else 0,
+                 fake_count=0 if is_real else 1
+             )
+         else:
+             if is_real:
+                 knowledge_graph.nodes[entity]['real_count'] += 1
+             else:
+                 knowledge_graph.nodes[entity]['fake_count'] += 1
+
+     for i, (entity1, _) in enumerate(entities):
+         for entity2, _ in entities[i+1:]:
+             if not knowledge_graph.has_edge(entity1, entity2):
+                 knowledge_graph.add_edge(
+                     entity1,
+                     entity2,
+                     weight=1,
+                     is_real=is_real
+                 )
+             else:
+                 knowledge_graph[entity1][entity2]['weight'] += 1
+
+ def extract_entities(text):
+     """Extract named entities from text using spaCy."""
+     doc = nlp(text)
+     entities = [(ent.text, ent.label_) for ent in doc.ents]
+     return entities
+
+ def predict_with_knowledge_graph(text):
+     """Predict whether the news is real or fake using the knowledge graph."""
+     entities = extract_entities(text)
+     real_score = 0
+     fake_score = 0
+
+     for entity, _ in entities:
+         if knowledge_graph.has_node(entity):
+             real_count = knowledge_graph.nodes[entity].get('real_count', 0)
+             fake_count = knowledge_graph.nodes[entity].get('fake_count', 0)
+             total = real_count + fake_count
+             if total > 0:
+                 real_score += real_count / total
+                 fake_score += fake_count / total
+
+     if real_score > fake_score:
+         return "REAL"
+     else:
+         return "FAKE"
+
+ def predict_news(text):
+     """Predict whether the news is real or fake using both the ML model and the knowledge graph."""
+     # Predict with the ML model
+     ml_prediction = predict_with_model(text)
+     is_real = ml_prediction == "REAL"
+
+     # Update the knowledge graph
+     update_knowledge_graph(text, is_real)
+
+     # Predict with the knowledge graph
+     kg_prediction = predict_with_knowledge_graph(text)
+
+     # Combine predictions: return the ML label when both methods agree,
+     # otherwise flag the article as UNCERTAIN (a weighted-score
+     # alternative is sketched below)
+     return ml_prediction if ml_prediction == kg_prediction else "UNCERTAIN"
+
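The agreement check above could be refined into a weighted combination of the two signals, as the function's comment suggests. A minimal sketch, not part of this commit: the name combine_scores and the weight alpha are illustrative, it reuses the module-level tokenizer, model, knowledge_graph, and extract_entities defined above, and it relies on this file's convention that class index 0 means REAL.

def combine_scores(text, alpha=0.7):
    """Hypothetical combiner: mix the ML model's softmax probability of
    REAL with the mean knowledge-graph reliability of the text's entities."""
    inputs = tokenizer(text, return_tensors="pt", truncation=True,
                       padding=True, max_length=512)
    with torch.no_grad():
        probs = torch.nn.functional.softmax(model(**inputs).logits, dim=-1)
    ml_real = probs[0][0].item()  # index 0 is the REAL class in this repo

    # Mean fraction of "real" sightings across known entities
    reliabilities = []
    for entity, _ in extract_entities(text):
        if knowledge_graph.has_node(entity):
            node = knowledge_graph.nodes[entity]
            total = node.get('real_count', 0) + node.get('fake_count', 0)
            if total > 0:
                reliabilities.append(node.get('real_count', 0) / total)
    kg_real = sum(reliabilities) / len(reliabilities) if reliabilities else 0.5  # neutral prior

    combined = alpha * ml_real + (1 - alpha) * kg_real
    return "REAL" if combined >= 0.5 else "FAKE"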
+ #########################
+ # def analyze_content_gemini(model, text):
+ #     prompt = f"""Analyze this news text and provide results in the following JSON-like format:
+
+ #     TEXT: {text}
+
+ #     Please provide analysis in these specific sections:
+
+ #     1. GEMINI ANALYSIS:
+ #     - Predicted Classification: [Real/Fake]
+ #     - Confidence Score: [0-100%]
+ #     - Reasoning: [Key points for classification]
+
+ #     2. TEXT CLASSIFICATION:
+ #     - Content category/topic
+ #     - Writing style: [Formal/Informal/Clickbait]
+ #     - Target audience
+ #     - Content type: [news/opinion/editorial]
+
+ #     3. SENTIMENT ANALYSIS:
+ #     - Primary emotion
+ #     - Emotional intensity (1-10)
+ #     - Sensationalism Level: [High/Medium/Low]
+ #     - Bias Indicators: [List if any]
+ #     - Tone: (formal/informal), [Professional/Emotional/Neutral]
+ #     - Key emotional triggers
+
+ #     4. ENTITY RECOGNITION:
+ #     - Source Credibility: [High/Medium/Low]
+ #     - People mentioned
+ #     - Organizations
+ #     - Locations
+ #     - Dates/Time references
+ #     - Key numbers/statistics
+
+ #     5. CONTEXT EXTRACTION:
+ #     - Main narrative/story
+ #     - Supporting elements
+ #     - Key claims
+ #     - Narrative structure
+
+ #     6. FACT CHECKING:
+ #     - Verifiable Claims: [List main claims]
+ #     - Evidence Present: [Yes/No]
+ #     - Fact Check Score: [0-100%]
+
+ #     Format the response clearly with distinct sections."""
+
+ #     response = model.generate_content(prompt)
+ #     return response.text
+
+ def analyze_content_gemini(model, text):
+     prompt = f"""Analyze this news text and return a JSON object with the following structure:
+     {{
+         "gemini_analysis": {{
+             "predicted_classification": "Real or Fake",
+             "confidence_score": "0-100",
+             "reasoning": ["point1", "point2"]
+         }},
+         "text_classification": {{
+             "category": "",
+             "writing_style": "Formal/Informal/Clickbait",
+             "target_audience": "",
+             "content_type": "news/opinion/editorial"
+         }},
+         "sentiment_analysis": {{
+             "primary_emotion": "",
+             "emotional_intensity": "1-10",
+             "sensationalism_level": "High/Medium/Low",
+             "bias_indicators": ["bias1", "bias2"],
+             "tone": {{"formality": "formal/informal", "style": "Professional/Emotional/Neutral"}},
+             "emotional_triggers": ["trigger1", "trigger2"]
+         }},
+         "entity_recognition": {{
+             "source_credibility": "High/Medium/Low",
+             "people": ["person1", "person2"],
+             "organizations": ["org1", "org2"],
+             "locations": ["location1", "location2"],
+             "dates": ["date1", "date2"],
+             "statistics": ["stat1", "stat2"]
+         }},
+         "context": {{
+             "main_narrative": "",
+             "supporting_elements": ["element1", "element2"],
+             "key_claims": ["claim1", "claim2"],
+             "narrative_structure": ""
+         }},
+         "fact_checking": {{
+             "verifiable_claims": ["claim1", "claim2"],
+             "evidence_present": "Yes/No",
+             "fact_check_score": "0-100"
+         }}
+     }}
+
+     Analyze this text and return only the JSON response: {text}"""
+
+     response = model.generate_content(prompt)
+     # Clean the response so it parses as JSON even when the model wraps
+     # it in a Markdown code fence
+     try:
+         cleaned_text = response.text.strip()
+         if cleaned_text.startswith('```'):
+             # Drop the opening ```json (or bare ```) line and the closing fence
+             cleaned_text = cleaned_text.split('\n', 1)[-1]
+             if cleaned_text.endswith('```'):
+                 cleaned_text = cleaned_text[:-3]
+         return json.loads(cleaned_text)
+     except json.JSONDecodeError:
+         # Return a default structured response if JSON parsing fails
+         return {
+             "gemini_analysis": {
+                 "predicted_classification": "UNCERTAIN",
+                 "confidence_score": "50",
+                 "reasoning": ["Analysis failed to generate valid JSON"]
+             }
+         }
+
+ def clean_gemini_output(text):
+     """Remove markdown formatting from Gemini output"""
+     text = text.replace('##', '')
+     text = text.replace('**', '')
+     return text
+
+ def get_gemini_analysis(text):
+     """Get detailed content analysis from Gemini."""
+     gemini_model = setup_gemini()
+     gemini_analysis = analyze_content_gemini(gemini_model, text)
+     # cleaned_analysis = clean_gemini_output(gemini_analysis)
+     # return cleaned_analysis
+     return gemini_analysis
+ #########################
+
+ def main():
+     print("Welcome to the News Classifier!")
+     print("Enter your news text below. Type 'Exit' to quit.")
+
+     while True:
+         news_text = input("\nEnter news text: ")
+
+         if news_text.lower() == 'exit':
+             print("Thank you for using the News Classifier!")
+             return
+
+         # First get ML and Knowledge Graph prediction
+         prediction = predict_news(news_text)
+         print(f"\nML and Knowledge Graph Analysis: {prediction}")
+
+         # Then get Gemini analysis
+         print("\n=== Detailed Gemini Analysis ===")
+         gemini_result = get_gemini_analysis(news_text)
+         print(json.dumps(gemini_result, indent=2))
+
+ if __name__ == "__main__":
+     main()
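One behavioral note on final.py: predict_news mutates the in-memory knowledge graph on every call, but nothing in this file ever writes it back to disk, so those updates are lost when the process exits. Persisting them would mean re-serializing the graph in the same nodes/edges dictionary format that knowledge_graph_generator.py's save_knowledge_graph (below) produces and this file's loader expects.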
knowledge_graph_generator.py ADDED
@@ -0,0 +1,170 @@
+ import pandas as pd
+ import networkx as nx
+ import spacy
+ import pickle
+ from datetime import datetime
+ import os
+
+ # Load spaCy for NER
+ nlp = spacy.load("en_core_web_sm")
+
+ class KnowledgeGraphBuilder:
+     def __init__(self, model_dir="models"):
+         self.model_dir = model_dir
+         self.knowledge_graph = nx.DiGraph()
+
+     def extract_entities(self, text):
+         """Extract named entities from text using spaCy"""
+         try:
+             # Convert to string and handle NaN/None values
+             if pd.isna(text) or text is None:
+                 return []
+
+             # Convert float or int to string if necessary
+             if isinstance(text, (float, int)):
+                 text = str(text)
+
+             # Ensure text is a string
+             text = str(text).strip()
+
+             # Skip empty strings
+             if not text:
+                 return []
+
+             doc = nlp(text)
+             entities = [(ent.text, ent.label_) for ent in doc.ents]
+             return entities
+         except Exception as e:
+             print(f"Error processing text: {text}")
+             print(f"Error message: {str(e)}")
+             return []
+
+     def update_knowledge_graph(self, text, is_real):
+         """Update knowledge graph with entities and their relationships"""
+         try:
+             entities = self.extract_entities(text)
+
+             # Skip if no entities were found
+             if not entities:
+                 return
+
+             # Add nodes and edges to the graph
+             for entity, entity_type in entities:
+                 # Add node if it doesn't exist
+                 if not self.knowledge_graph.has_node(entity):
+                     self.knowledge_graph.add_node(
+                         entity,
+                         type=entity_type,
+                         real_count=1 if is_real else 0,
+                         fake_count=0 if is_real else 1
+                     )
+                 else:
+                     # Update counts
+                     if is_real:
+                         self.knowledge_graph.nodes[entity]['real_count'] += 1
+                     else:
+                         self.knowledge_graph.nodes[entity]['fake_count'] += 1
+
+             # Add edges between entities in the same text
+             for i, (entity1, _) in enumerate(entities):
+                 for entity2, _ in entities[i+1:]:
+                     if not self.knowledge_graph.has_edge(entity1, entity2):
+                         self.knowledge_graph.add_edge(
+                             entity1,
+                             entity2,
+                             weight=1,
+                             is_real=is_real
+                         )
+                     else:
+                         self.knowledge_graph[entity1][entity2]['weight'] += 1
+         except Exception as e:
+             print(f"Error updating knowledge graph: {str(e)}")
+
+     def save_knowledge_graph(self, filename=None):
+         """Save the knowledge graph to a file"""
+         if filename is None:
+             timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+             filename = os.path.join(self.model_dir, f"knowledge_graph_{timestamp}.pkl")
+
+         os.makedirs(self.model_dir, exist_ok=True)
+
+         # Convert the graph to a dictionary format for better serialization
+         graph_data = {
+             'nodes': dict(self.knowledge_graph.nodes(data=True)),
+             'edges': {}
+         }
+
+         # Properly format edges with their data
+         for u, v, data in self.knowledge_graph.edges(data=True):
+             if u not in graph_data['edges']:
+                 graph_data['edges'][u] = {}
+             graph_data['edges'][u][v] = data
+
+         try:
+             with open(filename, 'wb') as f:
+                 pickle.dump(graph_data, f)
+             print(f"Knowledge graph saved to {filename}")
+             print(f"Total nodes: {len(graph_data['nodes'])}")
+             print(f"Total edges: {sum(len(edges) for edges in graph_data['edges'].values())}")
+             return filename
+         except Exception as e:
+             print(f"Error saving knowledge graph: {str(e)}")
+             return None
+
+     def get_graph_statistics(self):
+         """Get basic statistics about the knowledge graph"""
+         stats = {
+             'total_nodes': self.knowledge_graph.number_of_nodes(),
+             'total_edges': self.knowledge_graph.number_of_edges(),
+             'entity_types': {},
+             'reliability_scores': {}
+         }
+
+         # Count entity types
+         for node, attrs in self.knowledge_graph.nodes(data=True):
+             entity_type = attrs.get('type', 'UNKNOWN')
+             stats['entity_types'][entity_type] = stats['entity_types'].get(entity_type, 0) + 1
+
+             # Calculate reliability score
+             real_count = attrs.get('real_count', 0)
+             fake_count = attrs.get('fake_count', 0)
+             total = real_count + fake_count
+             if total > 0:
+                 reliability = real_count / total
+                 stats['reliability_scores'][node] = reliability
+
+         return stats
+
+ def main():
+     # Initialize the knowledge graph builder
+     builder = KnowledgeGraphBuilder()
+
+     # Load your dataset
+     df = pd.read_csv('./combined.csv')  # Replace with your actual data file
+
+     # Create knowledge graph
+     print("Building knowledge graph...")
+     total_rows = len(df)
+     for idx, row in df.iterrows():
+         try:
+             # Case-insensitive label check (nlp_trainer.py lower-cases this
+             # same column before mapping it)
+             builder.update_knowledge_graph(row['text'], str(row['label']).strip().upper() == 'REAL')
+             if (idx + 1) % 100 == 0:
+                 print(f"Processed {idx + 1}/{total_rows} entries ({(idx + 1)/total_rows*100:.1f}%)...")
+         except Exception as e:
+             print(f"Error processing row {idx}: {str(e)}")
+             continue
+
+     # Save the knowledge graph
+     graph_path = builder.save_knowledge_graph()
+
+     # Print statistics
+     stats = builder.get_graph_statistics()
+     print("\nKnowledge Graph Statistics:")
+     print(f"Total nodes: {stats['total_nodes']}")
+     print(f"Total edges: {stats['total_edges']}")
+     print("\nEntity types distribution:")
+     for entity_type, count in stats['entity_types'].items():
+         print(f"{entity_type}: {count}")
+
+ if __name__ == "__main__":
+     main()
models/knowledge_graph.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:2f6259a6e81cc6c739d239b3846fc112238e206f65f0999184c86e1539c43ab9
+ size 249881241
nlp_trainer.py ADDED
@@ -0,0 +1,130 @@
+ import pandas as pd
+ from sklearn.model_selection import train_test_split
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
+ from transformers import Trainer, TrainingArguments
+ from torch.utils.data import Dataset
+ import torch
+ import logging
+ logging.basicConfig(level=logging.INFO)
+
+ def load_dataset(path="./combined.csv"):
+     df = pd.read_csv(path, dtype={'text': str, 'label': str})  # Explicitly set dtypes
+     df = df.dropna()  # Remove any null values
+
+     # Ensure consistent column names
+     if 'news' in df.columns:
+         df = df.rename(columns={"news": "text"})
+     if 'target' in df.columns:
+         df = df.rename(columns={"target": "label"})
+
+     # Convert labels to integers safely
+     label_map = {"real": 0, "fake": 1}
+     df['label'] = df['label'].str.lower().map(label_map)
+
+     # Drop any rows where label mapping failed
+     df = df.dropna(subset=['label'])
+     df['label'] = df['label'].astype(int)
+
+     X = df['text'].apply(str).tolist()  # Ensure text is string
+     y = df['label'].tolist()
+
+     return train_test_split(X, y, test_size=0.2, random_state=42)
+
+ class NewsDataset(Dataset):
+     def __init__(self, texts, labels, tokenizer, max_len):
+         self.texts = texts
+         self.labels = labels
+         self.tokenizer = tokenizer
+         self.max_len = max_len
+
+     def __len__(self):
+         return len(self.texts)
+
+     def __getitem__(self, idx):
+         text = str(self.texts[idx])
+         encoding = self.tokenizer(
+             text,
+             max_length=self.max_len,
+             padding='max_length',
+             truncation=True,
+             return_tensors="pt"
+         )
+         return {
+             'input_ids': encoding['input_ids'].squeeze(0),
+             'attention_mask': encoding['attention_mask'].squeeze(0),
+             'labels': torch.tensor(int(self.labels[idx]), dtype=torch.long)
+         }
+
+ def train_model(train_texts, train_labels, val_texts, val_labels):
+     tokenizer = AutoTokenizer.from_pretrained('microsoft/deberta-v3-small')
+     model = AutoModelForSequenceClassification.from_pretrained('microsoft/deberta-v3-small', num_labels=2)
+
+     train_dataset = NewsDataset(train_texts, train_labels, tokenizer, max_len=128)
+     val_dataset = NewsDataset(val_texts, val_labels, tokenizer, max_len=128)
+
+     training_args = TrainingArguments(
+         output_dir='./results',
+         num_train_epochs=5,
+         per_device_train_batch_size=8,
+         per_device_eval_batch_size=8,
+         warmup_steps=500,
+         weight_decay=0.01,
+         logging_dir='./logs',
+         evaluation_strategy="epoch",
+         save_strategy="epoch"
+     )
+
+     trainer = Trainer(
+         model=model,
+         args=training_args,
+         train_dataset=train_dataset,
+         eval_dataset=val_dataset
+         # no compute_metrics is supplied, so evaluation reports only
+         # eval_loss; see the metric sketch after this function
+     )
+
+     trainer.train()
+     return tokenizer, model
+
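Evaluation is configured to run every epoch, yet no compute_metrics function is passed to the Trainer, so only eval_loss gets logged (the checkpoint's trainer_state.json below confirms this). A minimal sketch of an accuracy metric that could be wired in; the function name is illustrative, and scikit-learn (already a dependency via train_test_split) supplies the scorer:

from sklearn.metrics import accuracy_score

def compute_metrics(eval_pred):
    # The Trainer hands over (logits, labels) as numpy arrays
    logits, labels = eval_pred
    preds = logits.argmax(axis=-1)
    return {"accuracy": accuracy_score(labels, preds)}

# then construct the trainer with: Trainer(..., compute_metrics=compute_metrics)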
+ def predict_news(tokenizer, model, news_text):
+     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+     model.to(device)
+     logging.info(f"Using device: {device}")
+     model.eval()
+
+     encoding = tokenizer(
+         str(news_text),
+         max_length=128,
+         padding='max_length',
+         truncation=True,
+         return_tensors="pt"
+     )
+
+     input_ids = encoding['input_ids'].to(device)
+     attention_mask = encoding['attention_mask'].to(device)
+
+     with torch.no_grad():
+         outputs = model(input_ids=input_ids, attention_mask=attention_mask)
+         prediction = torch.argmax(outputs.logits, dim=1).item()
+
+     return "Fake" if prediction == 1 else "Real"
+
+ def main():
+     try:
+         X_train, X_test, y_train, y_test = load_dataset()
+         tokenizer, model = train_model(X_train, y_train, X_test, y_test)
+
+         while True:
+             user_input = input("\nEnter news text (or 'exit' to quit): ")
+             if user_input.lower() == 'exit':
+                 break
+             result = predict_news(tokenizer, model, user_input)
+             print(f"The news is: {result}")
+
+     except Exception as e:
+         logging.error(f"An error occurred: {str(e)}")
+         raise
+
+ if __name__ == "__main__":
+     main()
package-lock.json ADDED
@@ -0,0 +1,111 @@
+ {
+   "name": "complete_nlp_stuff",
+   "lockfileVersion": 3,
+   "requires": true,
+   "packages": {
+     "": {
+       "dependencies": {
+         "axios": "^1.7.9"
+       }
+     },
+     "node_modules/asynckit": {
+       "version": "0.4.0",
+       "resolved": "https://registry.npmjs.org/asynckit/-/asynckit-0.4.0.tgz",
+       "integrity": "sha512-Oei9OH4tRh0YqU3GxhX79dM/mwVgvbZJaSNaRk+bshkj0S5cfHcgYakreBjrHwatXKbz+IoIdYLxrKim2MjW0Q==",
+       "license": "MIT"
+     },
+     "node_modules/axios": {
+       "version": "1.7.9",
+       "resolved": "https://registry.npmjs.org/axios/-/axios-1.7.9.tgz",
+       "integrity": "sha512-LhLcE7Hbiryz8oMDdDptSrWowmB4Bl6RCt6sIJKpRB4XtVf0iEgewX3au/pJqm+Py1kCASkb/FFKjxQaLtxJvw==",
+       "license": "MIT",
+       "dependencies": {
+         "follow-redirects": "^1.15.6",
+         "form-data": "^4.0.0",
+         "proxy-from-env": "^1.1.0"
+       }
+     },
+     "node_modules/combined-stream": {
+       "version": "1.0.8",
+       "resolved": "https://registry.npmjs.org/combined-stream/-/combined-stream-1.0.8.tgz",
+       "integrity": "sha512-FQN4MRfuJeHf7cBbBMJFXhKSDq+2kAArBlmRBvcvFE5BB1HZKXtSFASDhdlz9zOYwxh8lDdnvmMOe/+5cdoEdg==",
+       "license": "MIT",
+       "dependencies": {
+         "delayed-stream": "~1.0.0"
+       },
+       "engines": {
+         "node": ">= 0.8"
+       }
+     },
+     "node_modules/delayed-stream": {
+       "version": "1.0.0",
+       "resolved": "https://registry.npmjs.org/delayed-stream/-/delayed-stream-1.0.0.tgz",
+       "integrity": "sha512-ZySD7Nf91aLB0RxL4KGrKHBXl7Eds1DAmEdcoVawXnLD7SDhpNgtuII2aAkg7a7QS41jxPSZ17p4VdGnMHk3MQ==",
+       "license": "MIT",
+       "engines": {
+         "node": ">=0.4.0"
+       }
+     },
+     "node_modules/follow-redirects": {
+       "version": "1.15.9",
+       "resolved": "https://registry.npmjs.org/follow-redirects/-/follow-redirects-1.15.9.tgz",
+       "integrity": "sha512-gew4GsXizNgdoRyqmyfMHyAmXsZDk6mHkSxZFCzW9gwlbtOW44CDtYavM+y+72qD/Vq2l550kMF52DT8fOLJqQ==",
+       "funding": [
+         {
+           "type": "individual",
+           "url": "https://github.com/sponsors/RubenVerborgh"
+         }
+       ],
+       "license": "MIT",
+       "engines": {
+         "node": ">=4.0"
+       },
+       "peerDependenciesMeta": {
+         "debug": {
+           "optional": true
+         }
+       }
+     },
+     "node_modules/form-data": {
+       "version": "4.0.1",
+       "resolved": "https://registry.npmjs.org/form-data/-/form-data-4.0.1.tgz",
+       "integrity": "sha512-tzN8e4TX8+kkxGPK8D5u0FNmjPUjw3lwC9lSLxxoB/+GtsJG91CO8bSWy73APlgAZzZbXEYZJuxjkHH2w+Ezhw==",
+       "license": "MIT",
+       "dependencies": {
+         "asynckit": "^0.4.0",
+         "combined-stream": "^1.0.8",
+         "mime-types": "^2.1.12"
+       },
+       "engines": {
+         "node": ">= 6"
+       }
+     },
+     "node_modules/mime-db": {
+       "version": "1.52.0",
+       "resolved": "https://registry.npmjs.org/mime-db/-/mime-db-1.52.0.tgz",
+       "integrity": "sha512-sPU4uV7dYlvtWJxwwxHD0PuihVNiE7TyAbQ5SWxDCB9mUYvOgroQOwYQQOKPJ8CIbE+1ETVlOoK1UC2nU3gYvg==",
+       "license": "MIT",
+       "engines": {
+         "node": ">= 0.6"
+       }
+     },
+     "node_modules/mime-types": {
+       "version": "2.1.35",
+       "resolved": "https://registry.npmjs.org/mime-types/-/mime-types-2.1.35.tgz",
+       "integrity": "sha512-ZDY+bPm5zTTF+YpCrAU9nK0UgICYPT0QtT1NZWFv4s++TNkcgVaT0g6+4R2uI4MjQjzysHB1zxuWL50hzaeXiw==",
+       "license": "MIT",
+       "dependencies": {
+         "mime-db": "1.52.0"
+       },
+       "engines": {
+         "node": ">= 0.6"
+       }
+     },
+     "node_modules/proxy-from-env": {
+       "version": "1.1.0",
+       "resolved": "https://registry.npmjs.org/proxy-from-env/-/proxy-from-env-1.1.0.tgz",
+       "integrity": "sha512-D+zkORCbA9f1tdWRK0RaCR3GPv50cMxcrz4X8k5LTSUD1Dkw47mKJEZQNunItRTkWwgtaUSo1RVFRIG9ZXiFYg==",
+       "license": "MIT"
+     }
+   }
+ }
package.json ADDED
@@ -0,0 +1,5 @@
+ {
+   "dependencies": {
+     "axios": "^1.7.9"
+   }
+ }
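No JavaScript source is part of this commit and node_modules/ is gitignored, so the lone axios dependency presumably serves the separate React frontend referenced in app.py's CORS comment (http://localhost:5173).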
results/checkpoint-5030/config.json ADDED
@@ -0,0 +1,35 @@
+ {
+   "_name_or_path": "microsoft/deberta-v3-small",
+   "architectures": [
+     "DebertaV2ForSequenceClassification"
+   ],
+   "attention_probs_dropout_prob": 0.1,
+   "hidden_act": "gelu",
+   "hidden_dropout_prob": 0.1,
+   "hidden_size": 768,
+   "initializer_range": 0.02,
+   "intermediate_size": 3072,
+   "layer_norm_eps": 1e-07,
+   "max_position_embeddings": 512,
+   "max_relative_positions": -1,
+   "model_type": "deberta-v2",
+   "norm_rel_ebd": "layer_norm",
+   "num_attention_heads": 12,
+   "num_hidden_layers": 6,
+   "pad_token_id": 0,
+   "pooler_dropout": 0,
+   "pooler_hidden_act": "gelu",
+   "pooler_hidden_size": 768,
+   "pos_att_type": [
+     "p2c",
+     "c2p"
+   ],
+   "position_biased_input": false,
+   "position_buckets": 256,
+   "relative_attention": true,
+   "share_att_key": true,
+   "torch_dtype": "float32",
+   "transformers_version": "4.46.2",
+   "type_vocab_size": 0,
+   "vocab_size": 128100
+ }
results/checkpoint-5030/model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:f34f9b72aa96cb0927c5cfcdad25c0281212e297d61dd14dcacdb68138c40840
+ size 567598552
results/checkpoint-5030/optimizer.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:cddba7c9ed0694f75f418657613b8400183c22b1e86f0d5fac90de0153d72e5f
+ size 1135260474
results/checkpoint-5030/rng_state.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:5d0c9d10259d2c7407ae8f630db471aed45598cb19d4fec8b8a17555906525a5
+ size 14244
results/checkpoint-5030/scheduler.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:3f0b07a36064ffcbc9c9cdc658bf6076e72b04ada218a099af03a6b74a3518d1
+ size 1064
results/checkpoint-5030/trainer_state.json ADDED
@@ -0,0 +1,143 @@
+ {
+   "best_metric": null,
+   "best_model_checkpoint": null,
+   "epoch": 5.0,
+   "eval_steps": 500,
+   "global_step": 5030,
+   "is_hyper_param_search": false,
+   "is_local_process_zero": true,
+   "is_world_process_zero": true,
+   "log_history": [
+     {
+       "epoch": 0.4970178926441352,
+       "grad_norm": 11.328213691711426,
+       "learning_rate": 5e-05,
+       "loss": 0.3471,
+       "step": 500
+     },
+     {
+       "epoch": 0.9940357852882704,
+       "grad_norm": 0.29149460792541504,
+       "learning_rate": 4.448123620309051e-05,
+       "loss": 0.1462,
+       "step": 1000
+     },
+     {
+       "epoch": 1.0,
+       "eval_loss": 0.14880910515785217,
+       "eval_runtime": 32.5193,
+       "eval_samples_per_second": 61.871,
+       "eval_steps_per_second": 7.749,
+       "step": 1006
+     },
+     {
+       "epoch": 1.4910536779324055,
+       "grad_norm": 0.04432953894138336,
+       "learning_rate": 3.896247240618102e-05,
+       "loss": 0.0738,
+       "step": 1500
+     },
+     {
+       "epoch": 1.9880715705765408,
+       "grad_norm": 0.004722778219729662,
+       "learning_rate": 3.3443708609271526e-05,
+       "loss": 0.0599,
+       "step": 2000
+     },
+     {
+       "epoch": 2.0,
+       "eval_loss": 0.17704755067825317,
+       "eval_runtime": 32.4526,
+       "eval_samples_per_second": 61.998,
+       "eval_steps_per_second": 7.765,
+       "step": 2012
+     },
+     {
+       "epoch": 2.485089463220676,
+       "grad_norm": 0.0014285552315413952,
+       "learning_rate": 2.792494481236203e-05,
+       "loss": 0.0176,
+       "step": 2500
+     },
+     {
+       "epoch": 2.982107355864811,
+       "grad_norm": 0.0008603875176049769,
+       "learning_rate": 2.240618101545254e-05,
+       "loss": 0.026,
+       "step": 3000
+     },
+     {
+       "epoch": 3.0,
+       "eval_loss": 0.16322186589241028,
+       "eval_runtime": 32.2403,
+       "eval_samples_per_second": 62.406,
+       "eval_steps_per_second": 7.816,
+       "step": 3018
+     },
+     {
+       "epoch": 3.4791252485089466,
+       "grad_norm": 0.000587798363994807,
+       "learning_rate": 1.688741721854305e-05,
+       "loss": 0.0042,
+       "step": 3500
+     },
+     {
+       "epoch": 3.9761431411530817,
+       "grad_norm": 0.00033068188349716365,
+       "learning_rate": 1.1368653421633555e-05,
+       "loss": 0.0012,
+       "step": 4000
+     },
+     {
+       "epoch": 4.0,
+       "eval_loss": 0.20389850437641144,
+       "eval_runtime": 33.2829,
+       "eval_samples_per_second": 60.452,
+       "eval_steps_per_second": 7.571,
+       "step": 4024
+     },
+     {
+       "epoch": 4.473161033797217,
+       "grad_norm": 0.0048806252889335155,
+       "learning_rate": 5.8498896247240626e-06,
+       "loss": 0.0013,
+       "step": 4500
+     },
+     {
+       "epoch": 4.970178926441352,
+       "grad_norm": 0.00042022508569061756,
+       "learning_rate": 3.3112582781456954e-07,
+       "loss": 0.0006,
+       "step": 5000
+     },
+     {
+       "epoch": 5.0,
+       "eval_loss": 0.19458653032779694,
+       "eval_runtime": 33.1006,
+       "eval_samples_per_second": 60.784,
+       "eval_steps_per_second": 7.613,
+       "step": 5030
+     }
+   ],
+   "logging_steps": 500,
+   "max_steps": 5030,
+   "num_input_tokens_seen": 0,
+   "num_train_epochs": 5,
+   "save_steps": 500,
+   "stateful_callbacks": {
+     "TrainerControl": {
+       "args": {
+         "should_epoch_stop": false,
+         "should_evaluate": false,
+         "should_log": false,
+         "should_save": true,
+         "should_training_stop": true
+       },
+       "attributes": {}
+     }
+   },
+   "total_flos": 1332007138928640.0,
+   "train_batch_size": 8,
+   "trial_name": null,
+   "trial_params": null
+ }
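The log above merits a note: eval_loss bottoms out at 0.149 after epoch 1 and drifts up to 0.195 by epoch 5 while training loss falls toward zero, a mild overfitting signature. Since a checkpoint is saved every epoch, deploying the epoch-1 checkpoint, or rerunning with load_best_model_at_end=True and metric_for_best_model="eval_loss" in TrainingArguments, would likely serve better than this final checkpoint-5030.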
results/checkpoint-5030/training_args.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:2e34c99e352dd9e22706f7f1143f42ff1385e64d6b188ee3ed83ab034094c017
+ size 5240