File size: 3,532 Bytes
d903275 0a2c880 2ff483e 727c299 2ff483e d903275 2ff483e 7b6b4a2 727c299 7b6b4a2 727c299 7b6b4a2 727c299 97b7fcf dd0962f 5bd8a1a dd0962f 727c299 dd0962f 727c299 5bd8a1a 727c299 5bd8a1a 7b6b4a2 0a2c880 3b25749 727c299 dd0962f 7b6b4a2 0a2c880 5bf06ec 0a2c880 b99a30a 0a2c880 8139f95 b99a30a 8139f95 b99a30a 5bd8a1a 8139f95 dd0962f b99a30a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 |
import os

import gradio as gr
import requests
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
# Identifier handed to both the tokenizer loader and the model loader.
model_name = "Writer/palmyra-small"

# Use CUDA when torch reports it as available, otherwise stay on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

tokenizer = AutoTokenizer.from_pretrained(model_name)

# Load the causal LM, then move it onto the selected device.
model = AutoModelForCausalLM.from_pretrained(model_name)
model = model.to(device)
def get_movie_info(movie_title):
    """Look up a movie on TMDb and summarise the first search hit.

    Runs a TMDb search for ``movie_title``, takes the first result and
    fetches its detail record to build a one-line summary.

    Args:
        movie_title: Free-text title to search for.

    Returns:
        A ``(summary, image_url)`` tuple of strings. ``summary`` is the
        formatted title/year/genre text followed by a TMDb link,
        ``"Movie not found"`` when the search yields no results (or the
        title is blank), or ``"Error: ..."`` when the lookup fails.
        ``image_url`` is the poster URL, or ``""`` when there is no poster,
        no match, or an error.
    """
    # SECURITY: an API key is committed in this file. Prefer supplying it via
    # the TMDB_API_KEY environment variable; the literal remains only as a
    # backward-compatible fallback and should be rotated and removed.
    api_key = os.environ.get("TMDB_API_KEY", "20e959f0f28e6b3e3de49c50f358538a")
    # Without a timeout a stalled connection would block the chat handler.
    timeout_seconds = 10

    # Deliberately best-effort: this feeds a chat UI, so every failure is
    # reported as an "Error: ..." string instead of propagating.
    try:
        # A blank query cannot match anything; skip the network round trip.
        if not movie_title or not movie_title.strip():
            return "Movie not found", ""

        # Make a search query to TMDb
        search_response = requests.get(
            "https://api.themoviedb.org/3/search/movie",
            params={
                "api_key": api_key,
                "query": movie_title,
                "language": "en-US",
                "page": 1,
            },
            timeout=timeout_seconds,
        )
        # Surface HTTP failures (bad key, rate limit) as errors rather than
        # misreporting them as "Movie not found".
        search_response.raise_for_status()
        results = search_response.json().get("results")

        # Check if any results are found
        if not results:
            return "Movie not found", ""
        movie_id = results[0]["id"]

        # Fetch detailed information using the movie ID
        details_response = requests.get(
            f"https://api.themoviedb.org/3/movie/{movie_id}",
            params={
                "api_key": api_key,
                "language": "en-US",
            },
            timeout=timeout_seconds,
        )
        details_response.raise_for_status()
        details_data = details_response.json()

        # Extract relevant information
        title = details_data.get("title", "Unknown Title")
        # Slice only a real date: slicing the fallback text itself would
        # produce "Unkn", and an empty date would produce an empty year.
        release_date = details_data.get("release_date") or ""
        year = release_date[:4] or "Unknown Year"
        genre = ", ".join(
            entry["name"] for entry in details_data.get("genres") or []
        )
        tmdb_link = f"https://www.themoviedb.org/movie/{movie_id}"

        # Convert poster_path to a complete image URL
        poster_path = details_data.get("poster_path")
        image_url = (
            f"https://image.tmdb.org/t/p/w500{poster_path}" if poster_path else ""
        )

        return (
            f"Title: {title}, Year: {year}, Genre: {genre}\n"
            f"Find more info here: {tmdb_link}",
            image_url,
        )
    except Exception as e:
        return f"Error: {e}", ""
def generate_response(prompt):
    """Generate a model reply for ``prompt``, enriched with TMDb movie info.

    The user text is used both as the TMDb search query (via
    ``get_movie_info``) and as the USER turn of the prompt given to the
    language model.

    Args:
        prompt: The user's message.

    Returns:
        A ``(text, image_component)`` tuple: ``text`` combines the movie
        summary with the model's continuation, and ``image_component`` is a
        ``gr.Image`` holding the poster URL (no value when no poster exists).
    """
    # Call the get_movie_info function to enrich the response
    movie_info, image_url = get_movie_info(prompt)

    # Prompt layout is unchanged: the movie info follows the ASSISTANT tag.
    input_text_template = (
        "A chat between a curious user and an artificial intelligence assistant. "
        "The assistant gives helpful, detailed, and polite answers to the user's questions. "
        f"USER: {prompt} "
        "ASSISTANT:"
        f" Movie Info: {movie_info}"
    )

    model_inputs = tokenizer(input_text_template, return_tensors="pt").to(device)
    prompt_length = model_inputs["input_ids"].shape[-1]

    gen_conf = {
        "top_k": 20,
        # max_length counts the prompt tokens as well, and this prompt is
        # already longer than 20 tokens; max_new_tokens limits only the
        # continuation.
        "max_new_tokens": 20,
        "temperature": 0.6,
        "do_sample": True,
        "eos_token_id": tokenizer.eos_token_id,
    }
    output = model.generate(**model_inputs, **gen_conf)

    # The returned sequence starts with the prompt tokens; decode only what
    # the model added so the reply does not echo the prompt text.
    generated_text = tokenizer.decode(
        output[0][prompt_length:], skip_special_tokens=True
    ).strip()

    # Display image directly in the chat. `label` is used because `alt` is
    # not a documented gr.Image parameter; an empty URL becomes None so no
    # load of "" is attempted.
    image_component = gr.Image(value=image_url or None, label="Movie Poster")

    return (
        f"Movie Info:\n{movie_info}\n\nGenerated Response:\n{generated_text}\n",
        image_component,
    )
# Define chat function for gr.ChatInterface
def chat_function(message, history):
    """Produce the chat reply for one user message.

    Args:
        message: The text the user just submitted.
        history: Previous turns supplied by the chat interface. It is not
            read or modified here: gr.ChatInterface records each turn
            itself, so appending to it in this callback would duplicate
            entries.

    Returns:
        The reply text. The image component returned by
        ``generate_response`` is discarded because the chat callback is
        expected to return a single reply, not a ``(text, component)``
        tuple.
    """
    response, _image_component = generate_response(message)
    return response
# Build the chat UI around chat_function and start serving it.
# The launch still happens at module level, exactly as before.
chat_interface = gr.ChatInterface(fn=chat_function)
chat_interface.launch()
|