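"""Streamlit app: multimodal evidence-based misinformation classification.

The app captions an uploaded claim image with BLIP, enriches the claim text
with that caption, retrieves textual and visual evidence, classifies each
claim/evidence pair across the four modality combinations, and fuses those
predictions into a final verdict.
"""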
import os
import tempfile
from dataclasses import dataclass
from typing import List, Optional, Tuple

import pandas as pd
import streamlit as st
from PIL import Image
from transformers import BlipForConditionalGeneration, BlipProcessor

from evaluate import MisinformationPredictor
from src.evidence.im2im_retrieval import ImageCorpus
from src.evidence.text2text_retrieval import SemanticSimilarity
from src.utils.path_utils import get_project_root


@st.cache_resource
def load_captioning_model() -> Tuple[BlipProcessor, BlipForConditionalGeneration]:
    """Load the BLIP captioning model once per process.

    Streamlit re-executes this script on every interaction, so loading the
    model at module level would repeat the expensive initialisation.
    """
    processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
    model = BlipForConditionalGeneration.from_pretrained(
        "Salesforce/blip-image-captioning-large"
    )
    return processor, model


PROJECT_ROOT = get_project_root()


@dataclass
class Evidence:
    """A retrieved evidence item and its classification results."""

    evidence_id: str
    dataset: str
    text: Optional[str]
    image: Optional[Image.Image]
    caption: Optional[str]
    image_path: Optional[str]
    # Predictions in the order (text|text, text|image, image|text, image|image).
    classification_result_all: Optional[Tuple[str, str, str, str]] = None
    classification_result_final: Optional[str] = None


# Label set produced by the classifier.
CLASSIFICATION_CATEGORIES = ["support", "refute", "not_enough_information"]


def generate_caption(image: Image.Image) -> str:
    """Generate a caption for the given image with BLIP."""
    processor, model = load_captioning_model()
    try:
        with st.spinner("Generating caption..."):
            inputs = processor(image, return_tensors="pt")
            output = model.generate(**inputs)
            return processor.decode(output[0], skip_special_tokens=True)
    except Exception as e:
        st.error(f"Error generating caption: {e}")
        return ""


def enrich_text_with_caption(text: str, image_caption: str) -> str:
    """Append the generated image caption to the claim text."""
    if text and image_caption:
        return f"{text}. {image_caption}"
    return text or image_caption


@st.cache_data
def get_train_df() -> pd.DataFrame:
    data_dir = os.path.join(PROJECT_ROOT, "data", "preprocessed")
    train_csv_path = os.path.join(data_dir, "train_enriched.csv")
    return pd.read_csv(train_csv_path)


@st.cache_data
def get_test_df() -> pd.DataFrame:
    data_dir = os.path.join(PROJECT_ROOT, "data", "preprocessed")
    test_csv_path = os.path.join(data_dir, "test_enriched.csv")
    return pd.read_csv(test_csv_path)


@st.cache_resource
def get_semantic_similarity(
    train_embeddings_file: str,
    test_embeddings_file: str,
    train_df: pd.DataFrame,
    test_df: pd.DataFrame,
) -> SemanticSimilarity:
    # cache_resource rather than cache_data: the similarity index is a live
    # object that should be built once per process and shared, not pickled.
    return SemanticSimilarity(
        train_embeddings_file=train_embeddings_file,
        test_embeddings_file=test_embeddings_file,
        train_df=train_df,
        test_df=test_df,
    )


def retrieve_evidences_by_text(
    query: str,
    top_k: int = 5,
) -> List[Evidence]:
    """Retrieve evidence rows from preloaded embeddings and CSV data using
    semantic text similarity.

    Args:
        query (str): The query text to search with.
        top_k (int): Number of top results to retrieve.

    Returns:
        List[Evidence]: A list of retrieved evidence objects.
    """
    train_embeddings_file = os.path.join(PROJECT_ROOT, "train_embeddings.h5")
    test_embeddings_file = os.path.join(PROJECT_ROOT, "test_embeddings.h5")
    similarity = get_semantic_similarity(
        train_embeddings_file=train_embeddings_file,
        test_embeddings_file=test_embeddings_file,
        train_df=get_train_df(),
        test_df=get_test_df(),
    )
    evidences = []
    try:
        results = similarity.search(query=query, top_k=top_k)

        for evidence_id, score in results:
            # Evidence IDs look like "train_<id>" or "test_<id>".
            if evidence_id.startswith("train_"):
                df = similarity.train_csv
            elif evidence_id.startswith("test_"):
                df = similarity.test_csv
            else:
                continue

            evidence_dataset, evidence_id_number = evidence_id.split("_", 1)
            row = df[df["id"] == int(evidence_id_number)].iloc[0]
            evidence_text = row.get("evidence_enriched")
            evidence_image_caption = row.get("evidence_image_caption")
            evidence_image_path = row.get("evidence_image")
            evidence_image = None
            full_image_path = None

            if pd.notna(evidence_image_path):
                full_image_path = os.path.join(PROJECT_ROOT, evidence_image_path)
                try:
                    evidence_image = Image.open(full_image_path).convert("RGB")
                except Exception as e:
                    st.error(f"Failed to load image {evidence_image_path}: {e}")

            evidences.append(
                Evidence(
                    evidence_id=evidence_id_number,
                    dataset=evidence_dataset,
                    text=evidence_text,
                    image=evidence_image,
                    caption=evidence_image_caption,
                    image_path=full_image_path,
                )
            )
    except Exception as e:
        st.error(f"Error performing semantic search: {e}")

    return evidences


@st.cache_resource
def get_image_corpus(image_features: str) -> ImageCorpus:
    # A live retrieval index: build once per process and reuse across reruns.
    return ImageCorpus(image_features)


def retrieve_evidences_by_image(
    image_path: str,
    top_k: int = 5,
) -> List[Evidence]:
    """Retrieve evidence rows from preloaded image features and CSV data using
    visual similarity.

    Args:
        image_path (str): Path to the query image.
        top_k (int): Number of top results to retrieve.

    Returns:
        List[Evidence]: A list of retrieved evidence objects.
    """
    image_features = os.path.join(PROJECT_ROOT, "evidence_features.pkl")
    image_corpus = get_image_corpus(image_features)
    evidences = []
    try:
        results = image_corpus.retrieve_similar_images(image_path, top_k=top_k)

        for evidence_path, score in results:
            # Evidence file names look like "<id>_<...>" under a train/ or
            # test/ directory.
            evidence_id = evidence_path.split("/")[-1]
            evidence_id_number = evidence_id.split("_")[0]

            if "train" in evidence_path:
                df = get_train_df()
            elif "test" in evidence_path:
                df = get_test_df()
            else:
                continue

            row = df[df["id"] == int(evidence_id_number)].iloc[0]
            evidence_text = row.get("evidence_enriched")
            evidence_image_caption = row.get("evidence_image_caption")
            evidence_image_path = row.get("evidence_image")
            evidence_image = None
            full_image_path = None

            if pd.notna(evidence_image_path):
                full_image_path = os.path.join(PROJECT_ROOT, evidence_image_path)
                try:
                    evidence_image = Image.open(full_image_path).convert("RGB")
                except Exception as e:
                    st.error(f"Failed to load image {evidence_image_path}: {e}")

            evidences.append(
                Evidence(
                    evidence_id=evidence_id_number,
                    dataset=evidence_path.split("/")[-2],
                    text=evidence_text,
                    image=evidence_image,
                    caption=evidence_image_caption,
                    image_path=full_image_path,
                )
            )
    except Exception as e:
        st.error(f"Error performing image similarity search: {e}")

    return evidences


@st.cache_resource
def get_predictor() -> MisinformationPredictor:
    return MisinformationPredictor(model_path="ckpts/model.pt", device="cpu")
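
# NOTE: "ckpts/model.pt" above is resolved relative to the working directory;
# run the app from the project root (or adjust the path) so the predictor can
# find its weights.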


def classify_evidence(
    claim_text: str,
    claim_image_path: Optional[str],
    evidence_text: Optional[str],
    evidence_image_path: Optional[str],
) -> Tuple[str, str, str, str]:
    """Classify a claim/evidence pair across the four modality combinations."""
    predictor = get_predictor()
    predictions = predictor.evaluate(
        claim_text, claim_image_path, evidence_text, evidence_image_path
    )
    if predictions:
        return (
            predictions.get("text_text", "not_enough_information"),
            predictions.get("text_image", "not_enough_information"),
            predictions.get("image_text", "not_enough_information"),
            predictions.get("image_image", "not_enough_information"),
        )
    return (
        "not_enough_information",
        "not_enough_information",
        "not_enough_information",
        "not_enough_information",
    )


def display_evidence_tab(evidences: List[Evidence], tab_label: str):
    """Render the retrieved evidences inside the current tab."""
    with st.container():
        for index, evidence in enumerate(evidences):
            with st.container():
                st.subheader(f"Evidence {index + 1}")
                st.write(f"Evidence Dataset: {evidence.dataset}")
                st.write(f"Evidence ID: {evidence.evidence_id}")
                if evidence.image is not None:
                    st.image(
                        evidence.image,
                        caption="Evidence Image",
                        use_container_width=True,
                    )
                st.text_area(
                    "Evidence Caption",
                    value=evidence.caption or "No caption available.",
                    height=100,
                    key=f"caption_{tab_label}_{index}",
                    disabled=True,
                )
                st.text_area(
                    "Evidence Text",
                    value=evidence.text or "No text available.",
                    height=100,
                    key=f"text_{tab_label}_{index}",
                    disabled=True,
                )
                if evidence.classification_result_all:
                    st.write("**Classification:**")
                    st.write(f"**text|text:** {evidence.classification_result_all[0]}")
                    st.write(f"**text|image:** {evidence.classification_result_all[1]}")
                    st.write(f"**image|text:** {evidence.classification_result_all[2]}")
                    st.write(f"**image|image:** {evidence.classification_result_all[3]}")
                    st.write(
                        f"**Final classification result:** "
                        f"{evidence.classification_result_final}"
                    )


def get_final_classification(results: Tuple[str, str, str, str]) -> str:
    """Fuse the four modality-pair predictions into a single verdict.

    The direct pairs (text|text, image|image) are consulted first; the
    cross-modal pairs (text|image, image|text) act as a fallback.
    """
    text_text, text_image, image_text, image_image = results

    def resolve_classification(val1: str, val2: str) -> str:
        # Two agreeing definite labels win; a single definite label beats
        # "not_enough_information"; anything else stays undecided.
        if val1 == val2 and val1 in {"support", "refute"}:
            return val1
        if (val1 in {"support", "refute"} and val2 == "not_enough_information") or (
            val2 in {"support", "refute"} and val1 == "not_enough_information"
        ):
            return val1 if val1 != "not_enough_information" else val2
        return "not_enough_information"

    final_result = resolve_classification(text_text, image_image)
    if final_result != "not_enough_information":
        return final_result

    final_result = resolve_classification(text_image, image_text)
    if final_result != "not_enough_information":
        return final_result

    return "not_enough_information"


def main():
    st.title("Multimodal Evidence-Based Misinformation Classification")
    st.write("Upload claims that have image and/or text content to verify.")

    uploaded_image = st.file_uploader(
        "Upload an image (1 max)", type=["jpg", "jpeg", "png"], key="image_uploader"
    )

    image = None
    claim_image_path = None
    if uploaded_image:
        try:
            image = Image.open(uploaded_image).convert("RGB")
            st.image(image, caption="Uploaded Image", use_container_width=True)
            # Persist the upload so downstream components that expect a file
            # path (retrieval and the predictor) can read it.
            with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
                image.save(tmp.name)
                claim_image_path = tmp.name
        except Exception as e:
            st.error(f"Failed to display the image: {e}")

    input_text = st.text_area("Enter text (max 4096 characters)", "", max_chars=4096)

    col1, col2 = st.columns(2)
    with col1:
        top_k_text = st.slider(
            "Top-k Text Evidences", min_value=1, max_value=5, value=2, key="top_k_text"
        )
    with col2:
        top_k_image = st.slider(
            "Top-k Image Evidences",
            min_value=1,
            max_value=5,
            value=2,
            key="top_k_image",
        )

    if st.button("Verify Claim"):
        if not uploaded_image and not input_text:
            st.warning("Please upload an image or enter text.")
            return

        progress = st.progress(0)

        progress.progress(10)
        st.write("### Step 1: Generating caption...")
        image_caption = ""
        if image is not None:
            image_caption = generate_caption(image)
            st.write("**Generated Image Caption:**", image_caption)

        progress.progress(40)
        st.write("### Step 2: Enriching text...")
        enriched_text = enrich_text_with_caption(input_text, image_caption)
        st.write("**Enriched Text:**")
        st.write(enriched_text)

        progress.progress(50)
        st.write("### Step 3: Retrieving evidences by text...")
        if input_text:
            text_evidences = retrieve_evidences_by_text(enriched_text, top_k=top_k_text)
            st.write(f"Retrieved {len(text_evidences)} text evidences.")
        else:
            text_evidences = None
            st.write("Text modality is missing from the input claim!")

        progress.progress(70)
        st.write("### Step 4: Retrieving evidences by image...")
        if claim_image_path:
            image_evidences = retrieve_evidences_by_image(
                claim_image_path, top_k=top_k_image
            )
            st.write(f"Retrieved {len(image_evidences)} image evidences.")
        else:
            image_evidences = None
            st.write("Image modality is missing from the input claim!")

        progress.progress(90)
        st.write("### Step 5: Verifying claim with retrieved evidences...")
        for evidence in (text_evidences or []) + (image_evidences or []):
            evidence.classification_result_all = classify_evidence(
                claim_text=enriched_text,
                claim_image_path=claim_image_path,
                evidence_text=evidence.text,
                evidence_image_path=evidence.image_path,
            )
            evidence.classification_result_final = get_final_classification(
                evidence.classification_result_all
            )

        progress.progress(100)
        if text_evidences or image_evidences:
            st.write("## Results")
            tabs = st.tabs(["Text Evidences", "Image Evidences"])

            with tabs[0]:
                if text_evidences:
                    st.write("### Text Evidences")
                    display_evidence_tab(text_evidences, "text")
                else:
                    st.write("Text modality is missing from the input claim!")

            with tabs[1]:
                if image_evidences:
                    st.write("### Image Evidences")
                    display_evidence_tab(image_evidences, "image")
                else:
                    st.write("Image modality is missing from the input claim!")


if __name__ == "__main__":
    main()