ConsumeWise / api /claims_analysis.py
shivanis14's picture
Update api/claims_analysis.py
f37f9cb verified
import sys
import os
import json
from typing import List, Dict, Any
import torch
from openai import OpenAI
from fastapi import FastAPI
app = FastAPI()
def create_assistant(client):
assistant3 = client.beta.assistants.create(
name="Misleading Claims",
instructions="You are an expert dietician. Use your knowledge base to answer questions about the misleading claims about food product.",
model="gpt-4o",
tools=[{"type": "file_search"}],
temperature=0,
top_p = 0.85
)
# Create a vector store
vector_store3 = client.beta.vector_stores.create(name="Misleading Claims Vec")
# Ready the files for upload to OpenAI
file_paths = ["docs/MisLeading_Claims.docx"]
file_streams = [open(path, "rb") for path in file_paths]
# Use the upload and poll SDK helper to upload the files, add them to the vector store,
# and poll the status of the file batch for completion.
file_batch3 = client.beta.vector_stores.file_batches.upload_and_poll(
vector_store_id=vector_store3.id, files=file_streams
)
#Misleading Claims
assistant3 = client.beta.assistants.update(
assistant_id=assistant3.id,
tool_resources={"file_search": {"vector_store_ids": [vector_store3.id]}},
)
return assistant3
def analyze_claims(claims, ingredients, assistant_id, client):
thread = client.beta.threads.create(
messages=[
{
"role": "user",
"content": "A food product named has the following claims: " + ', '.join(claims) + " and ingredients: " + ', '.join(ingredients) + """. Please evaluate the validity of each claim as well as assess if the product name is misleading.
The output must be in JSON format as follows:
{
<claim_name>: {
'Verdict': <A judgment on the claim's accuracy, ranging from 'Accurate' to varying degrees of 'Misleading'>,
'Why?': <A concise, bulleted summary explaining the specific ingredients or aspects contributing to the discrepancy>,
'Detailed Analysis': <An in-depth explanation of the claim, incorporating relevant regulatory guidelines and health perspectives to support the verdict>
}
}
"""
}
]
)
run = client.beta.threads.runs.create_and_poll(
thread_id=thread.id,
assistant_id=assistant_id,
include=["step_details.tool_calls[*].file_search.results[*].content"]
)
# Polling loop to wait for a response in the thread
messages = []
max_retries = 10 # You can set a maximum retry limit
retries = 0
wait_time = 2 # Seconds to wait between retries
while retries < max_retries:
messages = list(client.beta.threads.messages.list(thread_id=thread.id, run_id=run.id))
if messages: # If we receive any messages, break the loop
break
retries += 1
time.sleep(wait_time)
# Check if we got the message content
if not messages:
raise TimeoutError("Processing Claims : No messages were returned after polling.")
message_content = messages[0].content[0].text
annotations = message_content.annotations
#citations = []
#print(f"Length of annotations is {len(annotations)}")
for index, annotation in enumerate(annotations):
if file_citation := getattr(annotation, "file_citation", None):
#cited_file = client.files.retrieve(file_citation.file_id)
#citations.append(f"[{index}] {cited_file.filename}")
message_content.value = message_content.value.replace(annotation.text, "")
#if debug_mode:
# claims_not_found_in_doc = []
# print(message_content.value)
# for key, value in json.loads(message_content.value.replace("```", "").replace("json", "")).items():
# if value.startswith("(NOT FOUND IN DOCUMENT)"):
# claims_not_found_in_doc.append(key)
# print(f"Claims not found in the doc are {','.join(claims_not_found_in_doc)}")
#claims_analysis = json.loads(message_content.value.replace("```", "").replace("json", "").replace("(NOT FOUND IN DOCUMENT) ", ""))
claims_analysis = {}
if message_content.value != "":
claims_analysis = json.loads(message_content.value.replace("```", "").replace("json", ""))
claims_analysis_str = ""
for key, value in claims_analysis.items():
claims_analysis_str += f"{key}: {value}\n"
return claims_analysis_str
@app.post("/claims-analysis")
def get_claims_analysis(product_info_from_db):
if product_info_from_db:
brand_name = product_info_from_db.get("brandName", "")
product_name = product_info_from_db.get("productName", "")
claims_list = product_info_from_db.get("claims", [])
ingredients_list = [ingredient["name"] for ingredient in product_info_from_db.get("ingredients", [])]
claims_analysis = ""
if len(claims_list) > 0:
#Create client
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
#Create assistant for processing level
assistant_c = create_assistant(client)
#Create embeddings
claims_analysis = analyze_claims(claims_list, ingredients_list, assistant_c.id, client) if claims_list else ""
print(f"Returning claims_analysis : {claims_analysis}")
return {'claims_analysis' : claims_analysis}