# misinfo / app.py
# gyigit's picture
# update
# 54e8a79
import streamlit as st
from PIL import Image
from transformers import BlipProcessor, BlipForConditionalGeneration
import pandas as pd
import os
from evaluate import MisinformationPredictor
from src.evidence.im2im_retrieval import ImageCorpus
from src.evidence.text2text_retrieval import SemanticSimilarity
from src.utils.path_utils import get_project_root
from typing import List, Optional, Tuple
from dataclasses import dataclass
# Initialize BLIP model and processor
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
model = BlipForConditionalGeneration.from_pretrained(
"Salesforce/blip-image-captioning-large"
)
PROJECT_ROOT = get_project_root()
@dataclass
class Evidence:
evidence_id: str
dataset: str
text: Optional[str]
image: Optional[Image.Image]
caption: Optional[str]
image_path: Optional[str]
classification_result_all: Optional[Tuple[str, str, str, str]] = None
classification_result_final: Optional[str] = None
CLASSIFICATION_CATEGORIES = ["support", "refute", "not_enough_information"]
def generate_caption(image: Image.Image) -> str:
"""Generates a caption for a given image."""
try:
with st.spinner("Generating caption..."):
inputs = processor(image, return_tensors="pt")
output = model.generate(**inputs)
return processor.decode(output[0], skip_special_tokens=True)
except Exception as e:
st.error(f"Error generating caption: {e}")
return ""
def enrich_text_with_caption(text: str, image_caption: str) -> str:
"""Appends the image caption to the given text."""
if image_caption:
return f"{text}. {image_caption}"
return text
@st.cache_data
def get_train_df():
data_dir = os.path.join(PROJECT_ROOT, "data", "preprocessed")
train_csv_path = os.path.join(data_dir, "train_enriched.csv")
return pd.read_csv(train_csv_path)
@st.cache_data
def get_test_df():
data_dir = os.path.join(PROJECT_ROOT, "data", "preprocessed")
train_csv_path = os.path.join(data_dir, "test_enriched.csv")
return pd.read_csv(train_csv_path)
@st.cache_data
def get_semantic_similarity(
train_embeddings_file: str,
test_embeddings_file: str,
train_df: pd.DataFrame,
test_df: pd.DataFrame,
):
return SemanticSimilarity(
train_embeddings_file=train_embeddings_file,
test_embeddings_file=test_embeddings_file,
train_df=train_df,
test_df=test_df,
)
def retrieve_evidences_by_text(
query: str,
top_k: int = 5,
) -> List[Evidence]:
"""
Retrieves evidence rows from preloaded embeddings and CSV data using semantic similarity.
Args:
query (str): The query text to perform the search.
top_k (int): Number of top results to retrieve.
Returns:
List[Evidence]: A list of retrieved evidence objects.
"""
train_embeddings_file = os.path.join(PROJECT_ROOT, "train_embeddings.h5")
test_embeddings_file = os.path.join(PROJECT_ROOT, "test_embeddings.h5")
similarity = get_semantic_similarity(
train_embeddings_file=train_embeddings_file,
test_embeddings_file=test_embeddings_file,
train_df=get_train_df(),
test_df=get_test_df(),
)
evidences = []
try:
# Perform semantic search across both train and test datasets
results = similarity.search(query=query, top_k=top_k)
# Retrieve evidence rows based on the search results
for evidence_id, score in results:
# Determine whether the ID belongs to train or test set
if evidence_id.startswith("train_"):
df = similarity.train_csv
elif evidence_id.startswith("test_"):
df = similarity.test_csv
else:
continue # Skip invalid IDs
# Extract the row by ID
row = df[df["id"] == int(evidence_id.split("_")[1])].iloc[0]
evidence_text = row.get("evidence_enriched")
evidence_image_caption = row.get("evidence_image_caption")
evidence_image_path = row.get("evidence_image")
evidence_image = None
full_image_path = None
# Load the image if a valid path is provided
if pd.notna(evidence_image_path):
full_image_path = os.path.join(PROJECT_ROOT, evidence_image_path)
try:
evidence_image = Image.open(full_image_path).convert("RGB")
except Exception as e:
st.error(f"Failed to load image {evidence_image_path}: {e}")
evidence_id_number = evidence_id.split("_")[1]
evidence_dataset = evidence_id.split("_")[0]
# Create an Evidence object
evidences.append(
Evidence(
text=evidence_text,
image=evidence_image,
caption=evidence_image_caption,
evidence_id=evidence_id_number,
dataset=evidence_dataset,
image_path=full_image_path,
)
)
except Exception as e:
st.error(f"Error performing semantic search: {e}")
return evidences
@st.cache_data
def get_image_corpus(image_features):
return ImageCorpus(image_features)
def retrieve_evidences_by_image(
image_path: str,
top_k: int = 5,
) -> List[Evidence]:
"""
Retrieves evidence rows from preloaded embeddings and CSV data using semantic similarity.
Args:
query (str): The query text to perform the search.
top_k (int): Number of top results to retrieve.
Returns:
List[Evidence]: A list of retrieved evidence objects.
"""
image_features = os.path.join(PROJECT_ROOT, "evidence_features.pkl")
image_corpus = get_image_corpus(image_features)
evidences = []
try:
# Perform semantic search across both train and test datasets
results = image_corpus.retrieve_similar_images(image_path, top_k=top_k)
# Retrieve evidence rows based on the search results
for evidence_path, score in results:
evidence_id = evidence_path.split("/")[-1]
evidence_id_number = evidence_id.split("_")[0]
# Determine whether the ID belongs to train or test set
if "train" in evidence_path:
df = get_train_df()
elif "test" in evidence_path:
df = get_test_df()
else:
continue # Skip invalid IDs
# Extract the row by ID
row = df[df["id"] == int(evidence_id_number)].iloc[0]
evidence_text = row.get("evidence_enriched")
evidence_image_caption = row.get("evidence_image_caption")
evidence_image_path = row.get("evidence_image")
evidence_image = None
full_image_path = None
# Load the image if a valid path is provided
if pd.notna(evidence_image_path):
full_image_path = os.path.join(PROJECT_ROOT, evidence_image_path)
try:
evidence_image = Image.open(full_image_path).convert("RGB")
except Exception as e:
st.error(f"Failed to load image {evidence_image_path}: {e}")
# Create an Evidence object
evidences.append(
Evidence(
text=evidence_text,
image=evidence_image,
caption=evidence_image_caption,
dataset=evidence_path.split("/")[-2],
evidence_id=evidence_id_number,
image_path=full_image_path,
)
)
except Exception as e:
st.error(f"Error performing semantic search: {e}")
return evidences
@st.cache_resource
def get_predictor():
return MisinformationPredictor(model_path="ckpts/model.pt", device="cpu")
def classify_evidence(
claim_text: str, claim_image_path: str, evidence_text: str, evidence_image_path: str
) -> Tuple[str, str, str, str]:
"""Assigns a random classification to each evidence."""
predictor = get_predictor()
predictions = predictor.evaluate(
claim_text, claim_image_path, evidence_text, evidence_image_path
)
if predictions:
return (
predictions.get("text_text", "not_enough_information"),
predictions.get("text_image", "not_enough_information"),
predictions.get("image_text", "not_enough_information"),
predictions.get("image_image", "not_enough_information"),
)
else:
return (
"not_enough_information",
"not_enough_information",
"not_enough_information",
"not_enough_information",
)
def display_evidence_tab(evidences: List[Evidence], tab_label: str):
"""Displays evidence in a tabbed format."""
with st.container():
for index, evidence in enumerate(evidences):
with st.container():
st.subheader(f"Evidence {index + 1}")
st.write(f"Evidence Dataset: {evidence.dataset}")
st.write(f"Evidence ID: {evidence.evidence_id}")
if evidence.image:
st.image(
evidence.image,
caption="Evidence Image",
use_container_width=True,
)
st.text_area(
"Evidence Caption",
value=evidence.caption or "No caption available.",
height=100,
key=f"caption_{tab_label}_{index}",
disabled=True,
)
st.text_area(
"Evidence Text",
value=evidence.text or "No text available.",
height=100,
key=f"text_{tab_label}_{index}",
disabled=True,
)
if evidence.classification_result_all:
st.write("**Classification:**")
st.write(f"**text|text:** {evidence.classification_result_all[0]}")
st.write(f"**text|image:** {evidence.classification_result_all[1]}")
st.write(f"**image|text:** {evidence.classification_result_all[2]}")
st.write(
f"**image|image:** {evidence.classification_result_all[3]}"
)
st.write(
f"**Final classification result:** {evidence.classification_result_final}"
)
def get_final_classification(results: Tuple[str, str, str, str]) -> str:
text_text = results[0]
text_image = results[1]
image_text = results[2]
image_image = results[3]
# Helper function to determine the final classification based on two inputs
def resolve_classification(val1: str, val2: str) -> str:
if val1 == val2 and val1 in {"support", "refute"}:
return val1
if (val1 in {"support", "refute"} and val2 == "not_enough_information") or (
val2 in {"support", "refute"} and val1 == "not_enough_information"
):
return val1 if val1 != "not_enough_information" else val2
return "not_enough_information"
# Step 1: Check text_text and image_image
final_result = resolve_classification(text_text, image_image)
if final_result != "not_enough_information":
return final_result
# Step 2: Check text_image and image_text
final_result = resolve_classification(text_image, image_text)
if final_result != "not_enough_information":
return final_result
# Step 3: If still undetermined, return "not_enough_information"
return "not_enough_information"
def main():
st.title("Multimodal Evidence-Based Misinformation Classification")
st.write("Upload claims that have image and/or text content to verify.")
# File uploader for images
uploaded_image = st.file_uploader(
"Upload an image (1 max)", type=["jpg", "jpeg", "png"], key="image_uploader"
)
if uploaded_image:
try:
image = Image.open(uploaded_image).convert("RGB")
st.image(image, caption="Uploaded Image", use_container_width=True)
except Exception as e:
st.error(f"Failed to display the image: {e}")
# Text input field
input_text = st.text_area("Enter text (max 4096 characters)", "", max_chars=4096)
# Sliders for top_k values
col1, col2 = st.columns(2)
with col1:
top_k_text = st.slider(
"Top-k Text Evidences", min_value=1, max_value=5, value=2, key="top_k_text"
)
with col2:
top_k_image = st.slider(
"Top-k Image Evidences",
min_value=1,
max_value=5,
value=2,
key="top_k_image",
)
# Generate Enriched Text button
if st.button("Verify Claim"):
if not uploaded_image and not input_text:
st.warning("Please upload an image or enter text.")
return
progress = st.progress(0)
# Step 1: Generate caption
progress.progress(10)
st.write("### Step 1: Generating caption...")
image_caption = ""
if uploaded_image:
image_caption = generate_caption(image)
st.write("**Generated Image Caption:**", image_caption)
# Step 2: Enrich text
progress.progress(40)
st.write("### Step 2: Enriching text...")
enriched_text = enrich_text_with_caption(input_text, image_caption)
st.write("**Enriched Text:**")
st.write(enriched_text)
# Step 3: Retrieve evidences by text
progress.progress(50)
st.write("### Step 3: Retrieving evidences by text...")
if input_text:
text_evidences = retrieve_evidences_by_text(enriched_text, top_k=top_k_text)
st.write(f"Retrieved {len(text_evidences)} text evidences.")
else:
text_evidences = None
st.write("Text modality is missing from the input claim!")
# Step 4: Retrieve evidences by image
progress.progress(70)
st.write("### Step 4: Retrieving evidences by image...")
if uploaded_image:
image_evidences = retrieve_evidences_by_image(
uploaded_image, top_k=top_k_image
)
st.write(f"Retrieved {len(image_evidences)} image evidences.")
else:
image_evidences = None
st.write("Image modality is missing from the input claim!")
# Step 5: Classify evidences
progress.progress(90)
st.write("### Step 5: Verifying claim with retrieved evidences...")
for evidence in (text_evidences or []) + (image_evidences or []):
a, b, c, d = classify_evidence(
claim_text=enriched_text,
claim_image_path=uploaded_image,
evidence_text=evidence.text,
evidence_image_path=evidence.image_path,
)
evidence.classification_result_all = a, b, c, d
evidence.classification_result_final = get_final_classification(
evidence.classification_result_all
)
# Step 6: Display evidences
progress.progress(100)
if text_evidences or image_evidences:
st.write("## Results")
tabs = st.tabs(["Text Evidences", "Image Evidences"])
with tabs[0]:
if text_evidences:
st.write("### Text Evidences")
display_evidence_tab(text_evidences, "text")
else:
st.write("Text modality is missing from the input claim!")
with tabs[1]:
if image_evidences:
st.write("### Image Evidences")
display_evidence_tab(image_evidences, "image")
else:
st.write("Image modality is missing from the input claim!")
if __name__ == "__main__":
main()