import json
import os

import dotenv
import spacy
import torch
import google.generativeai as genai
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Load environment variables (including the Gemini API key) from a local .env file.
dotenv.load_dotenv()

# spaCy pipeline used for named-entity extraction.
nlp = spacy.load("en_core_web_sm")

# Fine-tuned fake-news classifier: the tokenizer comes from the base
# DeBERTa-v3-small checkpoint, the weights from the local fine-tuning output.
model_path = "./results/checkpoint-753"
tokenizer = AutoTokenizer.from_pretrained("microsoft/deberta-v3-small")
model = AutoModelForSequenceClassification.from_pretrained(model_path)
model.eval()
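
# Optional GPU placement, a sketch only (not wired into the functions below; the
# script runs on CPU as written). If enabled, the tokenized inputs in
# predict_with_model would also need to be moved to the same device:
# device = "cuda" if torch.cuda.is_available() else "cpu"
# model.to(device)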


def setup_gemini():
    """Configure the Gemini client and return a gemini-pro model handle."""
    genai.configure(api_key=os.getenv("GEMINI_API"))
    gemini_model = genai.GenerativeModel("gemini-pro")
    return gemini_model
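
# The key name matches the os.getenv call above; a .env entry of the form below
# is assumed (placeholder value, not a real key):
#   GEMINI_API=your-api-key-here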


def predict_with_model(text):
    """Predict whether the news is real or fake using the fine-tuned ML model."""
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
    with torch.no_grad():
        outputs = model(**inputs)
    probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
    predicted_label = torch.argmax(probabilities, dim=-1).item()
    # Label mapping from the fine-tuning setup: index 1 = FAKE, index 0 = REAL.
    return "FAKE" if predicted_label == 1 else "REAL"


def extract_entities(text):
    """Extract named entities from text using spaCy."""
    doc = nlp(text)
    entities = [(ent.text, ent.label_) for ent in doc.ents]
    return entities
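
# Illustrative output (labels come from spaCy's en_core_web_sm model):
# extract_entities("Apple opened a new store in Paris on Monday.")
#   -> [("Apple", "ORG"), ("Paris", "GPE"), ("Monday", "DATE")]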


def predict_news(text):
    """Thin wrapper around predict_with_model; returns 'REAL' or 'FAKE'."""
    prediction = predict_with_model(text)
    return prediction


def analyze_content_gemini(gemini_model, text):
    """Ask Gemini for a structured JSON analysis of the news text."""
    prompt = f"""Analyze this news text and return a JSON object with the following structure:
    {{
        "gemini_analysis": {{
            "predicted_classification": "Real or Fake",
            "confidence_score": "0-100",
            "reasoning": ["point1", "point2"]
        }},
        "text_classification": {{
            "category": "",
            "writing_style": "Formal/Informal/Clickbait",
            "target_audience": "",
            "content_type": "news/opinion/editorial"
        }},
        "sentiment_analysis": {{
            "primary_emotion": "",
            "emotional_intensity": "1-10",
            "sensationalism_level": "High/Medium/Low",
            "bias_indicators": ["bias1", "bias2"],
            "tone": {{"formality": "formal/informal", "style": "Professional/Emotional/Neutral"}},
            "emotional_triggers": ["trigger1", "trigger2"]
        }},
        "entity_recognition": {{
            "source_credibility": "High/Medium/Low",
            "people": ["person1", "person2"],
            "organizations": ["org1", "org2"],
            "locations": ["location1", "location2"],
            "dates": ["date1", "date2"],
            "statistics": ["stat1", "stat2"]
        }},
        "context": {{
            "main_narrative": "",
            "supporting_elements": ["element1", "element2"],
            "key_claims": ["claim1", "claim2"],
            "narrative_structure": ""
        }},
        "fact_checking": {{
            "verifiable_claims": ["claim1", "claim2"],
            "evidence_present": "Yes/No",
            "fact_check_score": "0-100"
        }}
    }}

    Analyze this text and return only the JSON response: {text}"""

    response = gemini_model.generate_content(prompt)
    try:
        cleaned_text = response.text.strip()
        # Strip a ```json ... ``` fence if Gemini wrapped the response in one.
        if cleaned_text.startswith('```json'):
            cleaned_text = cleaned_text[7:-3]
        return json.loads(cleaned_text)
    except json.JSONDecodeError:
        # Fall back to a minimal structure when the response is not valid JSON.
        return {
            "gemini_analysis": {
                "predicted_classification": "UNCERTAIN",
                "confidence_score": "50",
                "reasoning": ["Analysis failed to generate valid JSON"]
            }
        }


def clean_gemini_output(text):
    """Remove markdown formatting ('##', '**') from free-text Gemini output."""
    text = text.replace('##', '')
    text = text.replace('**', '')
    return text


def get_gemini_analysis(text):
    """Get detailed content analysis from Gemini."""
    gemini_model = setup_gemini()
    gemini_analysis = analyze_content_gemini(gemini_model, text)
    return gemini_analysis
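
# Illustrative programmatic use (bypassing the CLI loop in main below); the
# sample headline is made up and the output depends on the fine-tuned
# checkpoint and on Gemini's response:
# sample = "City council approves new budget for road repairs."
# print(predict_news(sample))          # "REAL" or "FAKE"
# print(get_gemini_analysis(sample))   # dict matching the JSON structure in the prompt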


def main():
    print("Welcome to the News Classifier!")
    print("Enter your news text below. Type 'Exit' to quit.")

    while True:
        news_text = input("\nEnter news text: ")

        if news_text.lower() == 'exit':
            print("Thank you for using the News Classifier!")
            return

        # Quick REAL/FAKE verdict from the fine-tuned classifier.
        prediction = predict_news(news_text)
        print(f"\nML Analysis: {prediction}")

        # Structured analysis from Gemini.
        print("\n=== Detailed Gemini Analysis ===")
        gemini_result = get_gemini_analysis(news_text)
        print(gemini_result)


if __name__ == "__main__":
    main()