Krish Patel
committed on
Commit
·
990f77e
1
Parent(s):
81f219e
Model upload
Browse files- .gitignore +2 -0
- app.py +113 -0
- final.py +270 -0
- knowledge_graph_generator.py +170 -0
- models/knowledge_graph.pkl +3 -0
- nlp_trainer.py +130 -0
- package-lock.json +111 -0
- package.json +5 -0
- results/checkpoint-5030/config.json +35 -0
- results/checkpoint-5030/model.safetensors +3 -0
- results/checkpoint-5030/optimizer.pt +3 -0
- results/checkpoint-5030/rng_state.pth +3 -0
- results/checkpoint-5030/scheduler.pt +3 -0
- results/checkpoint-5030/trainer_state.json +143 -0
- results/checkpoint-5030/training_args.bin +3 -0
.gitignore
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
__pycache__/
|
2 |
+
node_modules/
|
app.py
ADDED
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# # from fastapi import FastAPI
|
2 |
+
# # from pydantic import BaseModel
|
3 |
+
# # from final import predict_news, get_gemini_analysis
|
4 |
+
|
5 |
+
# # app = FastAPI()
|
6 |
+
|
7 |
+
# # class NewsInput(BaseModel):
|
8 |
+
# # text: str
|
9 |
+
|
10 |
+
# # @app.post("/analyze")
|
11 |
+
# # async def analyze_news(news: NewsInput):
|
12 |
+
# # # Get ML and Knowledge Graph prediction
|
13 |
+
# # prediction = predict_news(news.text)
|
14 |
+
|
15 |
+
# # # Get Gemini analysis
|
16 |
+
# # gemini_analysis = get_gemini_analysis(news.text)
|
17 |
+
|
18 |
+
# # return {
|
19 |
+
# # "prediction": prediction,
|
20 |
+
# # "detailed_analysis": gemini_analysis
|
21 |
+
# # }
|
22 |
+
|
23 |
+
# # @app.get("/health")
|
24 |
+
# # async def health_check():
|
25 |
+
# # return {"status": "healthy"}
|
26 |
+
|
27 |
+
# from fastapi import FastAPI
|
28 |
+
# from fastapi.middleware.cors import CORSMiddleware
|
29 |
+
# from pydantic import BaseModel
|
30 |
+
# from final import predict_news, get_gemini_analysis
|
31 |
+
|
32 |
+
# app = FastAPI()
|
33 |
+
|
34 |
+
# # Add CORS middleware
|
35 |
+
# app.add_middleware(
|
36 |
+
# CORSMiddleware,
|
37 |
+
# allow_origins=["http://localhost:5173"], # Your React app's URL
|
38 |
+
# allow_credentials=True,
|
39 |
+
# allow_methods=["*"],
|
40 |
+
# allow_headers=["*"],
|
41 |
+
# )
|
42 |
+
|
43 |
+
# # Rest of your code remains the same
|
44 |
+
# class NewsInput(BaseModel):
|
45 |
+
# text: str
|
46 |
+
|
47 |
+
# @app.post("/analyze")
|
48 |
+
# async def analyze_news(news: NewsInput):
|
49 |
+
# prediction = predict_news(news.text)
|
50 |
+
# gemini_analysis = get_gemini_analysis(news.text)
|
51 |
+
|
52 |
+
# return {
|
53 |
+
# "prediction": prediction,
|
54 |
+
# "detailed_analysis": gemini_analysis
|
55 |
+
# }
|
56 |
+
|
57 |
+
import streamlit as st
|
58 |
+
from final import predict_news, get_gemini_analysis
|
59 |
+
|
60 |
+
def main():
    """Render the Streamlit fact-checking page.

    Reads news text from a text area and, once the "Analyze" button is
    pressed, shows the verdict returned by ``predict_news`` followed by the
    structured analysis returned by ``get_gemini_analysis``.

    The analysis dictionary is read defensively: a missing section or field
    is rendered as "N/A" instead of raising ``KeyError`` half-way through
    the page, because the analysis may legitimately come back without every
    section (for example when the model's reply could not be parsed).
    """
    st.title("News Fact Checker")
    st.write("Enter news text to analyze its authenticity")

    # Text input area
    news_text = st.text_area("Enter news text here:", height=200)

    if not st.button("Analyze"):
        return

    # Whitespace-only input carries nothing to classify.
    if not news_text or not news_text.strip():
        st.warning("Please enter some text to analyze")
        return

    with st.spinner("Analyzing..."):
        prediction = predict_news(news_text)
        gemini_analysis = get_gemini_analysis(news_text)

    if not isinstance(gemini_analysis, dict):
        gemini_analysis = {}

    def section(name):
        """Return the named analysis section, or an empty dict if unusable."""
        value = gemini_analysis.get(name)
        return value if isinstance(value, dict) else {}

    def field(section_name, key):
        """Return one field of a section, or "N/A" when it is absent."""
        return section(section_name).get(key, "N/A")

    st.header("Analysis Results")

    # Green/red for the two definite verdicts; anything else (such as an
    # inconclusive result) is shown in a neutral warning colour rather than
    # being presented as if it were a "fake" verdict.
    verdict_colors = {"REAL": "green", "FAKE": "red"}
    prediction_color = verdict_colors.get(prediction, "orange")
    st.markdown(
        f"### Prediction: <span style='color:{prediction_color}'>{prediction}</span>",
        unsafe_allow_html=True,
    )

    st.subheader("Detailed Analysis")
    col1, col2 = st.columns(2)

    with col1:
        st.markdown("#### Content Classification")
        st.write(f"Category: {field('text_classification', 'category')}")
        st.write(f"Writing Style: {field('text_classification', 'writing_style')}")
        st.write(f"Content Type: {field('text_classification', 'content_type')}")

    with col2:
        st.markdown("#### Sentiment Analysis")
        st.write(f"Primary Emotion: {field('sentiment_analysis', 'primary_emotion')}")
        st.write(f"Emotional Intensity: {field('sentiment_analysis', 'emotional_intensity')}/10")
        st.write(f"Sensationalism Level: {field('sentiment_analysis', 'sensationalism_level')}")

    st.markdown("#### Fact Checking")
    st.write(f"Evidence Present: {field('fact_checking', 'evidence_present')}")
    st.write(f"Fact Check Score: {field('fact_checking', 'fact_check_score')}/100")

    st.markdown("#### Verifiable Claims")
    claims = section('fact_checking').get('verifiable_claims') or []
    if isinstance(claims, str):
        # A single claim returned as a bare string would otherwise be
        # iterated character by character.
        claims = [claims]
    for claim in claims:
        st.write(f"- {claim}")


if __name__ == "__main__":
    main()
|
final.py
ADDED
@@ -0,0 +1,270 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
3 |
+
import networkx as nx
|
4 |
+
import spacy
|
5 |
+
import pickle
|
6 |
+
import pandas as pd
|
7 |
+
import google.generativeai as genai
|
8 |
+
import json
|
9 |
+
|
10 |
+
# spaCy pipeline used for named-entity recognition.
nlp = spacy.load("en_core_web_sm")

# Classifier weights are read from the local checkpoint directory, while the
# tokenizer is loaded by the base model's name. The model is switched to
# evaluation mode as soon as it is loaded.
model_path = "./results/checkpoint-5030"  # Replace with the actual path to your model
tokenizer = AutoTokenizer.from_pretrained('microsoft/deberta-v3-small')
model = AutoModelForSequenceClassification.from_pretrained(model_path).eval()
|
18 |
+
|
19 |
+
#########################
|
20 |
+
def setup_gemini(model_name='gemini-pro'):
    """Configure the Gemini client and return a generative model handle.

    The API key is read from the ``GEMINI_API_KEY`` environment variable
    (falling back to ``GOOGLE_API_KEY``) rather than being embedded in the
    source, so the secret is never committed to version control.

    NOTE(review): a key was previously hard-coded here and published with
    the repository; it should be treated as compromised and revoked.

    Args:
        model_name: Name passed to ``genai.GenerativeModel``. Defaults to
            ``'gemini-pro'``, the value this function always used.

    Returns:
        The ``genai.GenerativeModel`` instance for ``model_name``.

    Raises:
        RuntimeError: If neither environment variable holds a key.
    """
    import os  # local import so this block does not depend on file-level imports

    api_key = os.environ.get("GEMINI_API_KEY") or os.environ.get("GOOGLE_API_KEY")
    if not api_key:
        raise RuntimeError(
            "No Gemini API key configured: set the GEMINI_API_KEY "
            "(or GOOGLE_API_KEY) environment variable."
        )

    genai.configure(api_key=api_key)
    return genai.GenerativeModel(model_name)
|
24 |
+
#########################
|
25 |
+
|
26 |
+
# Rebuild the directed knowledge graph from its pickled dictionary form:
# graph_data['nodes'] maps node -> attribute dict and graph_data['edges']
# maps source -> {target -> attribute dict}.
graph_path = "./models/knowledge_graph.pkl"  # Replace with the actual path to your knowledge graph
with open(graph_path, 'rb') as f:
    # NOTE(review): unpickling can execute arbitrary code; only point
    # graph_path at a file from a trusted source.
    graph_data = pickle.load(f)

knowledge_graph = nx.DiGraph()
knowledge_graph.add_nodes_from(graph_data['nodes'].items())
knowledge_graph.add_edges_from(
    (source, target, attrs)
    for source, targets in graph_data['edges'].items()
    for target, attrs in targets.items()
)
|
36 |
+
|
37 |
+
def predict_with_model(text):
    """Classify ``text`` with the loaded sequence-classification model.

    The text is tokenized (truncated to at most 512 tokens), run through the
    model without gradient tracking, and the class with the highest softmax
    probability is selected.

    Returns:
        "FAKE" when the selected class index is 1, otherwise "REAL".
    """
    encoded = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=512,
    )
    with torch.no_grad():
        logits = model(**encoded).logits

    class_probs = torch.nn.functional.softmax(logits, dim=-1)
    if torch.argmax(class_probs, dim=-1).item() == 1:
        return "FAKE"
    return "REAL"
|
45 |
+
|
46 |
+
def update_knowledge_graph(text, is_real):
    """Record the entities of ``text`` in the module-level knowledge graph.

    Each extracted entity becomes a node (or has its counter bumped if it is
    already present): ``real_count`` when ``is_real`` is truthy,
    ``fake_count`` otherwise. Every ordered pair of entities taken from the
    extraction list (earlier entity -> later entity) gets an edge; a new
    edge starts at weight 1 and stores ``is_real``, an existing edge only
    has its weight incremented.
    """
    found = extract_entities(text)
    counter_key = 'real_count' if is_real else 'fake_count'

    for name, label in found:
        if knowledge_graph.has_node(name):
            knowledge_graph.nodes[name][counter_key] += 1
            continue
        knowledge_graph.add_node(
            name,
            type=label,
            real_count=1 if is_real else 0,
            fake_count=0 if is_real else 1,
        )

    for position, (source, _) in enumerate(found):
        for target, _ in found[position + 1:]:
            if knowledge_graph.has_edge(source, target):
                knowledge_graph[source][target]['weight'] += 1
            else:
                knowledge_graph.add_edge(source, target, weight=1, is_real=is_real)
|
74 |
+
|
75 |
+
def extract_entities(text):
    """Return ``(entity_text, entity_label)`` pairs found by spaCy in ``text``.

    The pairs appear in the order the parsed document's ``ents`` yields them.
    """
    return [(span.text, span.label_) for span in nlp(text).ents]
|
80 |
+
|
81 |
+
def predict_with_knowledge_graph(text):
    """Vote on ``text`` using the real/fake history of its known entities.

    For every extracted entity that is already a node in the knowledge
    graph, the fractions ``real_count / total`` and ``fake_count / total``
    are added to two running scores (nodes with no recorded counts are
    skipped).

    Returns:
        "REAL" only when the real score is strictly greater than the fake
        score; otherwise "FAKE" (this includes ties and texts with no known
        entities).
    """
    real_total = 0
    fake_total = 0

    for name, _ in extract_entities(text):
        if not knowledge_graph.has_node(name):
            continue
        attrs = knowledge_graph.nodes[name]
        real_seen = attrs.get('real_count', 0)
        fake_seen = attrs.get('fake_count', 0)
        seen = real_seen + fake_seen
        if not seen > 0:
            continue
        real_total += real_seen / seen
        fake_total += fake_seen / seen

    return "REAL" if real_total > fake_total else "FAKE"
|
100 |
+
|
101 |
+
def predict_news(text):
    """Combine the ML classifier and the knowledge graph into one verdict.

    Steps, in order:
      1. Classify ``text`` with the ML model.
      2. Update the knowledge graph with the text's entities, labelled with
         the ML verdict.
      3. Ask the (already updated) knowledge graph for its own verdict.

    Returns:
        The shared verdict ("REAL" or "FAKE") when both agree, otherwise
        "UNCERTAIN".
    """
    model_verdict = predict_with_model(text)

    # The graph is updated before it is queried, so its vote already
    # includes this article's ML-assigned label.
    update_knowledge_graph(text, model_verdict == "REAL")

    graph_verdict = predict_with_knowledge_graph(text)

    if model_verdict != graph_verdict:
        return "UNCERTAIN"
    return model_verdict
|
116 |
+
|
117 |
+
#########################
|
118 |
+
# def analyze_content_gemini(model, text):
|
119 |
+
# prompt = f"""Analyze this news text and provide results in the following JSON-like format:
|
120 |
+
|
121 |
+
# TEXT: {text}
|
122 |
+
|
123 |
+
# Please provide analysis in these specific sections:
|
124 |
+
|
125 |
+
# 1. GEMINI ANALYSIS:
|
126 |
+
# - Predicted Classification: [Real/Fake]
|
127 |
+
# - Confidence Score: [0-100%]
|
128 |
+
# - Reasoning: [Key points for classification]
|
129 |
+
|
130 |
+
# 2. TEXT CLASSIFICATION:
|
131 |
+
# - Content category/topic
|
132 |
+
# - Writing style: [Formal/Informal/Clickbait]
|
133 |
+
# - Target audience
|
134 |
+
# - Content type: [news/opinion/editorial]
|
135 |
+
|
136 |
+
# 3. SENTIMENT ANALYSIS:
|
137 |
+
# - Primary emotion
|
138 |
+
# - Emotional intensity (1-10)
|
139 |
+
# - Sensationalism Level: [High/Medium/Low]
|
140 |
+
# - Bias Indicators: [List if any]
|
141 |
+
# - Tone: (formal/informal), [Professional/Emotional/Neutral]
|
142 |
+
# - Key emotional triggers
|
143 |
+
|
144 |
+
# 4. ENTITY RECOGNITION:
|
145 |
+
# - Source Credibility: [High/Medium/Low]
|
146 |
+
# - People mentioned
|
147 |
+
# - Organizations
|
148 |
+
# - Locations
|
149 |
+
# - Dates/Time references
|
150 |
+
# - Key numbers/statistics
|
151 |
+
|
152 |
+
# 5. CONTEXT EXTRACTION:
|
153 |
+
# - Main narrative/story
|
154 |
+
# - Supporting elements
|
155 |
+
# - Key claims
|
156 |
+
# - Narrative structure
|
157 |
+
|
158 |
+
# 6. FACT CHECKING:
|
159 |
+
# - Verifiable Claims: [List main claims]
|
160 |
+
# - Evidence Present: [Yes/No]
|
161 |
+
# - Fact Check Score: [0-100%]
|
162 |
+
|
163 |
+
# Format the response clearly with distinct sections."""
|
164 |
+
|
165 |
+
# response = model.generate_content(prompt)
|
166 |
+
# return response.text
|
167 |
+
|
168 |
+
def _default_gemini_analysis(reason):
    """Build a complete analysis dict of placeholders, citing ``reason``."""
    return {
        "gemini_analysis": {
            "predicted_classification": "UNCERTAIN",
            "confidence_score": "50",
            "reasoning": [reason],
        },
        "text_classification": {
            "category": "N/A",
            "writing_style": "N/A",
            "target_audience": "N/A",
            "content_type": "N/A",
        },
        "sentiment_analysis": {
            "primary_emotion": "N/A",
            "emotional_intensity": "N/A",
            "sensationalism_level": "N/A",
            "bias_indicators": [],
            "tone": {"formality": "N/A", "style": "N/A"},
            "emotional_triggers": [],
        },
        "entity_recognition": {
            "source_credibility": "N/A",
            "people": [],
            "organizations": [],
            "locations": [],
            "dates": [],
            "statistics": [],
        },
        "context": {
            "main_narrative": "N/A",
            "supporting_elements": [],
            "key_claims": [],
            "narrative_structure": "N/A",
        },
        "fact_checking": {
            "verifiable_claims": [],
            "evidence_present": "N/A",
            "fact_check_score": "N/A",
        },
    }


def _strip_code_fence(raw):
    """Return ``raw`` without a surrounding Markdown code fence, if present."""
    body = raw.strip()
    if not body.startswith("```"):
        return body
    body = body[3:]
    # The opening fence may carry a "json" language tag.
    if body[:4].lower() == "json":
        body = body[4:]
    body = body.rstrip()
    # The closing fence is optional: replies are sometimes cut off before it.
    if body.endswith("```"):
        body = body[:-3]
    return body.strip()


def _fill_missing_sections(parsed, defaults):
    """Add to ``parsed`` every section/field of ``defaults`` it lacks."""
    for section_name, default_fields in defaults.items():
        current = parsed.get(section_name)
        if not isinstance(current, dict):
            parsed[section_name] = default_fields
            continue
        for key, value in default_fields.items():
            current.setdefault(key, value)
    return parsed


def analyze_content_gemini(model, text):
    """Ask ``model`` for a structured JSON analysis of ``text``.

    Args:
        model: Object exposing ``generate_content(prompt)`` whose result has
            a ``text`` attribute holding the reply.
        text: The news text to analyze.

    Returns:
        A dict with the sections ``gemini_analysis``, ``text_classification``,
        ``sentiment_analysis``, ``entity_recognition``, ``context`` and
        ``fact_checking``. Any section or field missing from the reply is
        filled with a placeholder so callers can index the result without
        guarding against ``KeyError``. When the reply is not a valid JSON
        object, an all-placeholder result is returned whose
        ``gemini_analysis`` is "UNCERTAIN" with confidence "50".
    """
    prompt = f"""Analyze this news text and return a JSON object with the following structure:
{{
    "gemini_analysis": {{
        "predicted_classification": "Real or Fake",
        "confidence_score": "0-100",
        "reasoning": ["point1", "point2"]
    }},
    "text_classification": {{
        "category": "",
        "writing_style": "Formal/Informal/Clickbait",
        "target_audience": "",
        "content_type": "news/opinion/editorial"
    }},
    "sentiment_analysis": {{
        "primary_emotion": "",
        "emotional_intensity": "1-10",
        "sensationalism_level": "High/Medium/Low",
        "bias_indicators": ["bias1", "bias2"],
        "tone": {{"formality": "formal/informal", "style": "Professional/Emotional/Neutral"}},
        "emotional_triggers": ["trigger1", "trigger2"]
    }},
    "entity_recognition": {{
        "source_credibility": "High/Medium/Low",
        "people": ["person1", "person2"],
        "organizations": ["org1", "org2"],
        "locations": ["location1", "location2"],
        "dates": ["date1", "date2"],
        "statistics": ["stat1", "stat2"]
    }},
    "context": {{
        "main_narrative": "",
        "supporting_elements": ["element1", "element2"],
        "key_claims": ["claim1", "claim2"],
        "narrative_structure": ""
    }},
    "fact_checking": {{
        "verifiable_claims": ["claim1", "claim2"],
        "evidence_present": "Yes/No",
        "fact_check_score": "0-100"
    }}
}}

Analyze this text and return only the JSON response: {text}"""

    response = model.generate_content(prompt)

    failure_reason = "Analysis failed to generate valid JSON"
    try:
        # json.JSONDecodeError is a ValueError. Reading ``response.text`` is
        # kept inside the try as well because the client library may raise
        # ValueError when the reply has no usable text part (TODO confirm
        # against the installed client version).
        parsed = json.loads(_strip_code_fence(response.text))
    except ValueError:
        return _default_gemini_analysis(failure_reason)

    # Valid JSON that is not an object (a list, a bare string, ...) cannot
    # be indexed by section name, so it is treated like a parse failure.
    if not isinstance(parsed, dict):
        return _default_gemini_analysis(failure_reason)

    return _fill_missing_sections(
        parsed,
        _default_gemini_analysis("Section missing from model response"),
    )
|
231 |
+
|
232 |
+
|
233 |
+
def clean_gemini_output(text):
    """Strip the Markdown markers ``##`` and ``**`` from ``text``.

    ``##`` occurrences are removed first, then ``**`` occurrences; all other
    characters are left untouched.
    """
    for marker in ('##', '**'):
        text = text.replace(marker, '')
    return text
|
238 |
+
|
239 |
+
def get_gemini_analysis(text):
    """Return the Gemini content analysis for ``text``.

    A Gemini model handle is created on every call via ``setup_gemini`` and
    the result of ``analyze_content_gemini`` is returned unchanged.
    """
    return analyze_content_gemini(setup_gemini(), text)
|
246 |
+
#########################
|
247 |
+
|
248 |
+
def main():
    """Run the interactive console classifier.

    Repeatedly prompts for news text. For each entry it prints the combined
    ML/knowledge-graph verdict and then the Gemini analysis. Typing "exit"
    (in any letter case) prints a farewell message and ends the loop.
    """
    print("Welcome to the News Classifier!")
    print("Enter your news text below. Type 'Exit' to quit.")

    while True:
        news_text = input("\nEnter news text: ")
        if news_text.lower() == 'exit':
            break

        # Combined ML + knowledge-graph verdict first.
        print(f"\nML and Knowledge Graph Analysis: {predict_news(news_text)}")

        # Then the detailed Gemini analysis.
        print("\n=== Detailed Gemini Analysis ===")
        print(get_gemini_analysis(news_text))

    print("Thank you for using the News Classifier!")


if __name__ == "__main__":
    main()
|
knowledge_graph_generator.py
ADDED
@@ -0,0 +1,170 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pandas as pd
|
2 |
+
import networkx as nx
|
3 |
+
import spacy
|
4 |
+
import pickle
|
5 |
+
from datetime import datetime
|
6 |
+
import os
|
7 |
+
|
8 |
+
# Load spaCy for NER
|
9 |
+
nlp = spacy.load("en_core_web_sm")
|
10 |
+
|
11 |
+
class KnowledgeGraphBuilder:
    """Build, summarize and persist an entity co-occurrence graph.

    Nodes are named entities carrying ``type``, ``real_count`` and
    ``fake_count`` attributes; directed edges link entities that appeared in
    the same text and carry ``weight`` and ``is_real`` attributes.
    """

    def __init__(self, model_dir="models"):
        """Create an empty graph; ``model_dir`` is the default save directory."""
        self.model_dir = model_dir
        self.knowledge_graph = nx.DiGraph()

    def extract_entities(self, text):
        """Extract named entities from ``text`` using spaCy.

        Args:
            text: The text to analyze. ``None`` and NaN yield no entities;
                any other value is converted with ``str``.

        Returns:
            A list of ``(entity_text, entity_label)`` tuples. The list is
            empty for missing or blank input, and also when processing
            fails (the error is printed rather than raised).
        """
        try:
            if text is None or pd.isna(text):
                return []

            # str() also covers numeric cells that pandas read as float/int.
            text = str(text).strip()
            if not text:
                return []

            return [(ent.text, ent.label_) for ent in nlp(text).ents]
        except Exception as e:
            # Best effort: one bad row must not abort a long build.
            print(f"Error processing text: {text}")
            print(f"Error message: {str(e)}")
            return []

    def update_knowledge_graph(self, text, is_real):
        """Add the entities of ``text`` and their co-occurrences to the graph.

        Args:
            text: Source text to extract entities from.
            is_real: Truthy when the text is labelled real; selects which
                counter (``real_count`` or ``fake_count``) is incremented
                and is stored on newly created edges.

        Errors are printed rather than raised.
        """
        try:
            entities = self.extract_entities(text)
            if not entities:
                return

            counter_key = 'real_count' if is_real else 'fake_count'
            for entity, entity_type in entities:
                if not self.knowledge_graph.has_node(entity):
                    self.knowledge_graph.add_node(
                        entity,
                        type=entity_type,
                        real_count=1 if is_real else 0,
                        fake_count=0 if is_real else 1,
                    )
                else:
                    self.knowledge_graph.nodes[entity][counter_key] += 1

            # Link every earlier entity to every later one in the same text.
            for i, (entity1, _) in enumerate(entities):
                for entity2, _ in entities[i + 1:]:
                    if not self.knowledge_graph.has_edge(entity1, entity2):
                        self.knowledge_graph.add_edge(
                            entity1,
                            entity2,
                            weight=1,
                            is_real=is_real,
                        )
                    else:
                        self.knowledge_graph[entity1][entity2]['weight'] += 1
        except Exception as e:
            print(f"Error updating knowledge graph: {str(e)}")

    def save_knowledge_graph(self, filename=None):
        """Pickle the graph as plain dictionaries.

        The saved object has the form
        ``{'nodes': {node: attrs}, 'edges': {source: {target: attrs}}}``.

        Args:
            filename: Destination path. When omitted, a timestamped file
                name inside ``self.model_dir`` is used.

        Returns:
            The path written, or ``None`` if saving failed (the error is
            printed).
        """
        if filename is None:
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            filename = os.path.join(self.model_dir, f"knowledge_graph_{timestamp}.pkl")

        os.makedirs(self.model_dir, exist_ok=True)

        graph_data = {
            'nodes': dict(self.knowledge_graph.nodes(data=True)),
            'edges': {},
        }
        for u, v, data in self.knowledge_graph.edges(data=True):
            graph_data['edges'].setdefault(u, {})[v] = data

        try:
            # An explicit filename may live outside model_dir; make sure its
            # directory exists so the save does not fail on a missing folder.
            target_dir = os.path.dirname(filename)
            if target_dir:
                os.makedirs(target_dir, exist_ok=True)

            with open(filename, 'wb') as f:
                pickle.dump(graph_data, f)
            print(f"Knowledge graph saved to {filename}")
            print(f"Total nodes: {len(graph_data['nodes'])}")
            print(f"Total edges: {sum(len(edges) for edges in graph_data['edges'].values())}")
            return filename
        except Exception as e:
            print(f"Error saving knowledge graph: {str(e)}")
            return None

    def get_graph_statistics(self):
        """Summarize the graph.

        Returns:
            A dict with ``total_nodes``, ``total_edges``, ``entity_types``
            (entity type -> node count, ``'UNKNOWN'`` when a node has no
            type) and ``reliability_scores`` (node -> ``real_count`` divided
            by the sum of both counters, for nodes where that sum is
            positive).
        """
        stats = {
            'total_nodes': self.knowledge_graph.number_of_nodes(),
            'total_edges': self.knowledge_graph.number_of_edges(),
            'entity_types': {},
            'reliability_scores': {},
        }

        for node, attrs in self.knowledge_graph.nodes(data=True):
            entity_type = attrs.get('type', 'UNKNOWN')
            stats['entity_types'][entity_type] = stats['entity_types'].get(entity_type, 0) + 1

            real_count = attrs.get('real_count', 0)
            fake_count = attrs.get('fake_count', 0)
            total = real_count + fake_count
            if total > 0:
                stats['reliability_scores'][node] = real_count / total

        return stats
|
137 |
+
|
138 |
+
def main():
    """Build a knowledge graph from ``./combined.csv`` and save it.

    Every row's ``text`` is fed to the builder together with a flag derived
    from its ``label`` column. Progress is printed every 100 rows, the graph
    is saved, and summary statistics are printed.

    The label comparison ignores letter case and surrounding whitespace, so
    "REAL", "Real" and "real" are all treated as real. Anything else is
    counted as fake, as before.
    """
    builder = KnowledgeGraphBuilder()

    # Load your dataset
    df = pd.read_csv('./combined.csv')  # Replace with your actual data file

    print("Building knowledge graph...")
    total_rows = len(df)
    # Progress is counted with enumerate because the index labels yielded by
    # iterrows() need not be consecutive integers.
    for processed, (idx, row) in enumerate(df.iterrows(), start=1):
        try:
            is_real = str(row['label']).strip().lower() == 'real'
            builder.update_knowledge_graph(row['text'], is_real)
            if processed % 100 == 0:
                print(f"Processed {processed}/{total_rows} entries ({processed/total_rows*100:.1f}%)...")
        except Exception as e:
            # Best effort: report the row and keep building.
            print(f"Error processing row {idx}: {str(e)}")
            continue

    graph_path = builder.save_knowledge_graph()
    if graph_path is None:
        print("Warning: the knowledge graph was not saved.")

    stats = builder.get_graph_statistics()
    print("\nKnowledge Graph Statistics:")
    print(f"Total nodes: {stats['total_nodes']}")
    print(f"Total edges: {stats['total_edges']}")
    print("\nEntity types distribution:")
    for entity_type, count in stats['entity_types'].items():
        print(f"{entity_type}: {count}")


if __name__ == "__main__":
    main()
|
models/knowledge_graph.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:2f6259a6e81cc6c739d239b3846fc112238e206f65f0999184c86e1539c43ab9
|
3 |
+
size 249881241
|
nlp_trainer.py
ADDED
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pandas as pd
|
2 |
+
import numpy as np
|
3 |
+
from sklearn.model_selection import train_test_split
|
4 |
+
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
5 |
+
from transformers import Trainer, TrainingArguments
|
6 |
+
from torch.utils.data import Dataset
|
7 |
+
import torch
|
8 |
+
import re
|
9 |
+
import string
|
10 |
+
import logging
|
11 |
+
logging.basicConfig(level=logging.INFO)
|
12 |
+
|
13 |
+
def load_dataset(path="./combined.csv"):
    """Load the labelled news CSV and split it into train and test parts.

    Rows containing any null value are dropped, the columns ``news`` and
    ``target`` (when present) are renamed to ``text`` and ``label``, and
    labels are lower-cased and mapped to integers ("real" -> 0,
    "fake" -> 1). Rows whose label is not one of these two are discarded.

    Returns:
        The result of ``train_test_split`` over the texts and labels with a
        20% test share and ``random_state=42``.
    """
    frame = pd.read_csv(path, dtype={'text': str, 'label': str}).dropna()

    # Normalize alternative column names.
    for old_name, new_name in (("news", "text"), ("target", "label")):
        if old_name in frame.columns:
            frame = frame.rename(columns={old_name: new_name})

    # Unmapped labels become NaN here and are removed right after.
    frame['label'] = frame['label'].str.lower().map({"real": 0, "fake": 1})
    frame = frame.dropna(subset=['label'])
    frame['label'] = frame['label'].astype(int)

    texts = [str(value) for value in frame['text']]
    targets = frame['label'].tolist()

    return train_test_split(texts, targets, test_size=0.2, random_state=42)
|
35 |
+
|
36 |
+
class NewsDataset(Dataset):
    """Dataset that tokenizes one news text per item, on access."""

    def __init__(self, texts, labels, tokenizer, max_len):
        """Store the texts, their labels, the tokenizer and the token limit."""
        self.texts, self.labels = texts, labels
        self.tokenizer, self.max_len = tokenizer, max_len

    def __len__(self):
        """Return the number of texts."""
        return len(self.texts)

    def __getitem__(self, idx):
        """Return ``input_ids``, ``attention_mask`` and ``labels`` for item ``idx``.

        The text is converted with ``str`` and tokenized with padding and
        truncation to ``max_len``. The leading dimension of the two
        tokenizer tensors is squeezed away, and the label is returned as a
        ``torch.long`` tensor.
        """
        encoded = self.tokenizer(
            str(self.texts[idx]),
            max_length=self.max_len,
            padding='max_length',
            truncation=True,
            return_tensors="pt",
        )
        sample = {key: encoded[key].squeeze(0) for key in ('input_ids', 'attention_mask')}
        sample['labels'] = torch.tensor(int(self.labels[idx]), dtype=torch.long)
        return sample
|
60 |
+
|
61 |
+
def train_model(train_texts, train_labels, val_texts, val_labels):
    """Fine-tune ``microsoft/deberta-v3-small`` as a two-class classifier.

    Training and validation texts are wrapped in ``NewsDataset`` objects
    with a 128-token limit and passed to a ``Trainer`` that evaluates and
    saves once per epoch.

    Returns:
        ``(tokenizer, model)`` after ``trainer.train()`` has completed.
    """
    base_checkpoint = 'microsoft/deberta-v3-small'
    tokenizer = AutoTokenizer.from_pretrained(base_checkpoint)
    model = AutoModelForSequenceClassification.from_pretrained(base_checkpoint, num_labels=2)

    train_dataset = NewsDataset(train_texts, train_labels, tokenizer, max_len=128)
    val_dataset = NewsDataset(val_texts, val_labels, tokenizer, max_len=128)

    trainer_settings = {
        'output_dir': './results',
        'num_train_epochs': 5,
        'per_device_train_batch_size': 8,
        'per_device_eval_batch_size': 8,
        'warmup_steps': 500,
        'weight_decay': 0.01,
        'logging_dir': './logs',
        'evaluation_strategy': "epoch",
        'save_strategy': "epoch",
    }

    Trainer(
        model=model,
        args=TrainingArguments(**trainer_settings),
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
    ).train()

    return tokenizer, model
|
89 |
+
|
90 |
+
def predict_news(tokenizer, model, news_text):
    """Classify ``news_text`` with the given tokenizer and model.

    The model is moved to CUDA when available (CPU otherwise), the chosen
    device is printed, and the model is put in evaluation mode. The text is
    converted with ``str`` and tokenized to a fixed length of 128.

    Returns:
        "Fake" when the arg-max class index is 1, otherwise "Real".
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    print(device)
    model.eval()

    batch = tokenizer(
        str(news_text),
        max_length=128,
        padding='max_length',
        truncation=True,
        return_tensors="pt",
    )

    with torch.no_grad():
        logits = model(
            input_ids=batch['input_ids'].to(device),
            attention_mask=batch['attention_mask'].to(device),
        ).logits
        winner = torch.argmax(logits, dim=1).item()

    return {1: "Fake"}.get(winner, "Real")
|
112 |
+
|
113 |
+
def main():
    """Train the classifier, then classify texts typed at the console.

    The dataset is loaded and split, the model is trained, and the user is
    prompted for news text until "exit" (in any letter case) is entered. Any
    exception is logged and then re-raised.
    """
    try:
        train_texts, val_texts, train_labels, val_labels = load_dataset()
        tokenizer, model = train_model(train_texts, train_labels, val_texts, val_labels)

        running = True
        while running:
            user_input = input("\nEnter news text (or 'exit' to quit): ")
            if user_input.lower() == 'exit':
                running = False
            else:
                result = predict_news(tokenizer, model, user_input)
                print(f"The news is: {result}")

    except Exception as e:
        logging.error(f"An error occurred: {str(e)}")
        raise


if __name__ == "__main__":
    main()
|
package-lock.json
ADDED
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"name": "complete_nlp_stuff",
|
3 |
+
"lockfileVersion": 3,
|
4 |
+
"requires": true,
|
5 |
+
"packages": {
|
6 |
+
"": {
|
7 |
+
"dependencies": {
|
8 |
+
"axios": "^1.7.9"
|
9 |
+
}
|
10 |
+
},
|
11 |
+
"node_modules/asynckit": {
|
12 |
+
"version": "0.4.0",
|
13 |
+
"resolved": "https://registry.npmjs.org/asynckit/-/asynckit-0.4.0.tgz",
|
14 |
+
"integrity": "sha512-Oei9OH4tRh0YqU3GxhX79dM/mwVgvbZJaSNaRk+bshkj0S5cfHcgYakreBjrHwatXKbz+IoIdYLxrKim2MjW0Q==",
|
15 |
+
"license": "MIT"
|
16 |
+
},
|
17 |
+
"node_modules/axios": {
|
18 |
+
"version": "1.7.9",
|
19 |
+
"resolved": "https://registry.npmjs.org/axios/-/axios-1.7.9.tgz",
|
20 |
+
"integrity": "sha512-LhLcE7Hbiryz8oMDdDptSrWowmB4Bl6RCt6sIJKpRB4XtVf0iEgewX3au/pJqm+Py1kCASkb/FFKjxQaLtxJvw==",
|
21 |
+
"license": "MIT",
|
22 |
+
"dependencies": {
|
23 |
+
"follow-redirects": "^1.15.6",
|
24 |
+
"form-data": "^4.0.0",
|
25 |
+
"proxy-from-env": "^1.1.0"
|
26 |
+
}
|
27 |
+
},
|
28 |
+
"node_modules/combined-stream": {
|
29 |
+
"version": "1.0.8",
|
30 |
+
"resolved": "https://registry.npmjs.org/combined-stream/-/combined-stream-1.0.8.tgz",
|
31 |
+
"integrity": "sha512-FQN4MRfuJeHf7cBbBMJFXhKSDq+2kAArBlmRBvcvFE5BB1HZKXtSFASDhdlz9zOYwxh8lDdnvmMOe/+5cdoEdg==",
|
32 |
+
"license": "MIT",
|
33 |
+
"dependencies": {
|
34 |
+
"delayed-stream": "~1.0.0"
|
35 |
+
},
|
36 |
+
"engines": {
|
37 |
+
"node": ">= 0.8"
|
38 |
+
}
|
39 |
+
},
|
40 |
+
"node_modules/delayed-stream": {
|
41 |
+
"version": "1.0.0",
|
42 |
+
"resolved": "https://registry.npmjs.org/delayed-stream/-/delayed-stream-1.0.0.tgz",
|
43 |
+
"integrity": "sha512-ZySD7Nf91aLB0RxL4KGrKHBXl7Eds1DAmEdcoVawXnLD7SDhpNgtuII2aAkg7a7QS41jxPSZ17p4VdGnMHk3MQ==",
|
44 |
+
"license": "MIT",
|
45 |
+
"engines": {
|
46 |
+
"node": ">=0.4.0"
|
47 |
+
}
|
48 |
+
},
|
49 |
+
"node_modules/follow-redirects": {
|
50 |
+
"version": "1.15.9",
|
51 |
+
"resolved": "https://registry.npmjs.org/follow-redirects/-/follow-redirects-1.15.9.tgz",
|
52 |
+
"integrity": "sha512-gew4GsXizNgdoRyqmyfMHyAmXsZDk6mHkSxZFCzW9gwlbtOW44CDtYavM+y+72qD/Vq2l550kMF52DT8fOLJqQ==",
|
53 |
+
"funding": [
|
54 |
+
{
|
55 |
+
"type": "individual",
|
56 |
+
"url": "https://github.com/sponsors/RubenVerborgh"
|
57 |
+
}
|
58 |
+
],
|
59 |
+
"license": "MIT",
|
60 |
+
"engines": {
|
61 |
+
"node": ">=4.0"
|
62 |
+
},
|
63 |
+
"peerDependenciesMeta": {
|
64 |
+
"debug": {
|
65 |
+
"optional": true
|
66 |
+
}
|
67 |
+
}
|
68 |
+
},
|
69 |
+
"node_modules/form-data": {
|
70 |
+
"version": "4.0.1",
|
71 |
+
"resolved": "https://registry.npmjs.org/form-data/-/form-data-4.0.1.tgz",
|
72 |
+
"integrity": "sha512-tzN8e4TX8+kkxGPK8D5u0FNmjPUjw3lwC9lSLxxoB/+GtsJG91CO8bSWy73APlgAZzZbXEYZJuxjkHH2w+Ezhw==",
|
73 |
+
"license": "MIT",
|
74 |
+
"dependencies": {
|
75 |
+
"asynckit": "^0.4.0",
|
76 |
+
"combined-stream": "^1.0.8",
|
77 |
+
"mime-types": "^2.1.12"
|
78 |
+
},
|
79 |
+
"engines": {
|
80 |
+
"node": ">= 6"
|
81 |
+
}
|
82 |
+
},
|
83 |
+
"node_modules/mime-db": {
|
84 |
+
"version": "1.52.0",
|
85 |
+
"resolved": "https://registry.npmjs.org/mime-db/-/mime-db-1.52.0.tgz",
|
86 |
+
"integrity": "sha512-sPU4uV7dYlvtWJxwwxHD0PuihVNiE7TyAbQ5SWxDCB9mUYvOgroQOwYQQOKPJ8CIbE+1ETVlOoK1UC2nU3gYvg==",
|
87 |
+
"license": "MIT",
|
88 |
+
"engines": {
|
89 |
+
"node": ">= 0.6"
|
90 |
+
}
|
91 |
+
},
|
92 |
+
"node_modules/mime-types": {
|
93 |
+
"version": "2.1.35",
|
94 |
+
"resolved": "https://registry.npmjs.org/mime-types/-/mime-types-2.1.35.tgz",
|
95 |
+
"integrity": "sha512-ZDY+bPm5zTTF+YpCrAU9nK0UgICYPT0QtT1NZWFv4s++TNkcgVaT0g6+4R2uI4MjQjzysHB1zxuWL50hzaeXiw==",
|
96 |
+
"license": "MIT",
|
97 |
+
"dependencies": {
|
98 |
+
"mime-db": "1.52.0"
|
99 |
+
},
|
100 |
+
"engines": {
|
101 |
+
"node": ">= 0.6"
|
102 |
+
}
|
103 |
+
},
|
104 |
+
"node_modules/proxy-from-env": {
|
105 |
+
"version": "1.1.0",
|
106 |
+
"resolved": "https://registry.npmjs.org/proxy-from-env/-/proxy-from-env-1.1.0.tgz",
|
107 |
+
"integrity": "sha512-D+zkORCbA9f1tdWRK0RaCR3GPv50cMxcrz4X8k5LTSUD1Dkw47mKJEZQNunItRTkWwgtaUSo1RVFRIG9ZXiFYg==",
|
108 |
+
"license": "MIT"
|
109 |
+
}
|
110 |
+
}
|
111 |
+
}
|
package.json
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"dependencies": {
|
3 |
+
"axios": "^1.7.9"
|
4 |
+
}
|
5 |
+
}
|
results/checkpoint-5030/config.json
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_name_or_path": "microsoft/deberta-v3-small",
|
3 |
+
"architectures": [
|
4 |
+
"DebertaV2ForSequenceClassification"
|
5 |
+
],
|
6 |
+
"attention_probs_dropout_prob": 0.1,
|
7 |
+
"hidden_act": "gelu",
|
8 |
+
"hidden_dropout_prob": 0.1,
|
9 |
+
"hidden_size": 768,
|
10 |
+
"initializer_range": 0.02,
|
11 |
+
"intermediate_size": 3072,
|
12 |
+
"layer_norm_eps": 1e-07,
|
13 |
+
"max_position_embeddings": 512,
|
14 |
+
"max_relative_positions": -1,
|
15 |
+
"model_type": "deberta-v2",
|
16 |
+
"norm_rel_ebd": "layer_norm",
|
17 |
+
"num_attention_heads": 12,
|
18 |
+
"num_hidden_layers": 6,
|
19 |
+
"pad_token_id": 0,
|
20 |
+
"pooler_dropout": 0,
|
21 |
+
"pooler_hidden_act": "gelu",
|
22 |
+
"pooler_hidden_size": 768,
|
23 |
+
"pos_att_type": [
|
24 |
+
"p2c",
|
25 |
+
"c2p"
|
26 |
+
],
|
27 |
+
"position_biased_input": false,
|
28 |
+
"position_buckets": 256,
|
29 |
+
"relative_attention": true,
|
30 |
+
"share_att_key": true,
|
31 |
+
"torch_dtype": "float32",
|
32 |
+
"transformers_version": "4.46.2",
|
33 |
+
"type_vocab_size": 0,
|
34 |
+
"vocab_size": 128100
|
35 |
+
}
|
results/checkpoint-5030/model.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:f34f9b72aa96cb0927c5cfcdad25c0281212e297d61dd14dcacdb68138c40840
|
3 |
+
size 567598552
|
results/checkpoint-5030/optimizer.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:cddba7c9ed0694f75f418657613b8400183c22b1e86f0d5fac90de0153d72e5f
|
3 |
+
size 1135260474
|
results/checkpoint-5030/rng_state.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:5d0c9d10259d2c7407ae8f630db471aed45598cb19d4fec8b8a17555906525a5
|
3 |
+
size 14244
|
results/checkpoint-5030/scheduler.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:3f0b07a36064ffcbc9c9cdc658bf6076e72b04ada218a099af03a6b74a3518d1
|
3 |
+
size 1064
|
results/checkpoint-5030/trainer_state.json
ADDED
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"best_metric": null,
|
3 |
+
"best_model_checkpoint": null,
|
4 |
+
"epoch": 5.0,
|
5 |
+
"eval_steps": 500,
|
6 |
+
"global_step": 5030,
|
7 |
+
"is_hyper_param_search": false,
|
8 |
+
"is_local_process_zero": true,
|
9 |
+
"is_world_process_zero": true,
|
10 |
+
"log_history": [
|
11 |
+
{
|
12 |
+
"epoch": 0.4970178926441352,
|
13 |
+
"grad_norm": 11.328213691711426,
|
14 |
+
"learning_rate": 5e-05,
|
15 |
+
"loss": 0.3471,
|
16 |
+
"step": 500
|
17 |
+
},
|
18 |
+
{
|
19 |
+
"epoch": 0.9940357852882704,
|
20 |
+
"grad_norm": 0.29149460792541504,
|
21 |
+
"learning_rate": 4.448123620309051e-05,
|
22 |
+
"loss": 0.1462,
|
23 |
+
"step": 1000
|
24 |
+
},
|
25 |
+
{
|
26 |
+
"epoch": 1.0,
|
27 |
+
"eval_loss": 0.14880910515785217,
|
28 |
+
"eval_runtime": 32.5193,
|
29 |
+
"eval_samples_per_second": 61.871,
|
30 |
+
"eval_steps_per_second": 7.749,
|
31 |
+
"step": 1006
|
32 |
+
},
|
33 |
+
{
|
34 |
+
"epoch": 1.4910536779324055,
|
35 |
+
"grad_norm": 0.04432953894138336,
|
36 |
+
"learning_rate": 3.896247240618102e-05,
|
37 |
+
"loss": 0.0738,
|
38 |
+
"step": 1500
|
39 |
+
},
|
40 |
+
{
|
41 |
+
"epoch": 1.9880715705765408,
|
42 |
+
"grad_norm": 0.004722778219729662,
|
43 |
+
"learning_rate": 3.3443708609271526e-05,
|
44 |
+
"loss": 0.0599,
|
45 |
+
"step": 2000
|
46 |
+
},
|
47 |
+
{
|
48 |
+
"epoch": 2.0,
|
49 |
+
"eval_loss": 0.17704755067825317,
|
50 |
+
"eval_runtime": 32.4526,
|
51 |
+
"eval_samples_per_second": 61.998,
|
52 |
+
"eval_steps_per_second": 7.765,
|
53 |
+
"step": 2012
|
54 |
+
},
|
55 |
+
{
|
56 |
+
"epoch": 2.485089463220676,
|
57 |
+
"grad_norm": 0.0014285552315413952,
|
58 |
+
"learning_rate": 2.792494481236203e-05,
|
59 |
+
"loss": 0.0176,
|
60 |
+
"step": 2500
|
61 |
+
},
|
62 |
+
{
|
63 |
+
"epoch": 2.982107355864811,
|
64 |
+
"grad_norm": 0.0008603875176049769,
|
65 |
+
"learning_rate": 2.240618101545254e-05,
|
66 |
+
"loss": 0.026,
|
67 |
+
"step": 3000
|
68 |
+
},
|
69 |
+
{
|
70 |
+
"epoch": 3.0,
|
71 |
+
"eval_loss": 0.16322186589241028,
|
72 |
+
"eval_runtime": 32.2403,
|
73 |
+
"eval_samples_per_second": 62.406,
|
74 |
+
"eval_steps_per_second": 7.816,
|
75 |
+
"step": 3018
|
76 |
+
},
|
77 |
+
{
|
78 |
+
"epoch": 3.4791252485089466,
|
79 |
+
"grad_norm": 0.000587798363994807,
|
80 |
+
"learning_rate": 1.688741721854305e-05,
|
81 |
+
"loss": 0.0042,
|
82 |
+
"step": 3500
|
83 |
+
},
|
84 |
+
{
|
85 |
+
"epoch": 3.9761431411530817,
|
86 |
+
"grad_norm": 0.00033068188349716365,
|
87 |
+
"learning_rate": 1.1368653421633555e-05,
|
88 |
+
"loss": 0.0012,
|
89 |
+
"step": 4000
|
90 |
+
},
|
91 |
+
{
|
92 |
+
"epoch": 4.0,
|
93 |
+
"eval_loss": 0.20389850437641144,
|
94 |
+
"eval_runtime": 33.2829,
|
95 |
+
"eval_samples_per_second": 60.452,
|
96 |
+
"eval_steps_per_second": 7.571,
|
97 |
+
"step": 4024
|
98 |
+
},
|
99 |
+
{
|
100 |
+
"epoch": 4.473161033797217,
|
101 |
+
"grad_norm": 0.0048806252889335155,
|
102 |
+
"learning_rate": 5.8498896247240626e-06,
|
103 |
+
"loss": 0.0013,
|
104 |
+
"step": 4500
|
105 |
+
},
|
106 |
+
{
|
107 |
+
"epoch": 4.970178926441352,
|
108 |
+
"grad_norm": 0.00042022508569061756,
|
109 |
+
"learning_rate": 3.3112582781456954e-07,
|
110 |
+
"loss": 0.0006,
|
111 |
+
"step": 5000
|
112 |
+
},
|
113 |
+
{
|
114 |
+
"epoch": 5.0,
|
115 |
+
"eval_loss": 0.19458653032779694,
|
116 |
+
"eval_runtime": 33.1006,
|
117 |
+
"eval_samples_per_second": 60.784,
|
118 |
+
"eval_steps_per_second": 7.613,
|
119 |
+
"step": 5030
|
120 |
+
}
|
121 |
+
],
|
122 |
+
"logging_steps": 500,
|
123 |
+
"max_steps": 5030,
|
124 |
+
"num_input_tokens_seen": 0,
|
125 |
+
"num_train_epochs": 5,
|
126 |
+
"save_steps": 500,
|
127 |
+
"stateful_callbacks": {
|
128 |
+
"TrainerControl": {
|
129 |
+
"args": {
|
130 |
+
"should_epoch_stop": false,
|
131 |
+
"should_evaluate": false,
|
132 |
+
"should_log": false,
|
133 |
+
"should_save": true,
|
134 |
+
"should_training_stop": true
|
135 |
+
},
|
136 |
+
"attributes": {}
|
137 |
+
}
|
138 |
+
},
|
139 |
+
"total_flos": 1332007138928640.0,
|
140 |
+
"train_batch_size": 8,
|
141 |
+
"trial_name": null,
|
142 |
+
"trial_params": null
|
143 |
+
}
|
results/checkpoint-5030/training_args.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:2e34c99e352dd9e22706f7f1143f42ff1385e64d6b188ee3ed83ab034094c017
|
3 |
+
size 5240
|