File size: 3,532 Bytes
d903275
0a2c880
2ff483e
727c299
2ff483e
d903275
 
 
 
2ff483e
7b6b4a2
727c299
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7b6b4a2
727c299
 
7b6b4a2
727c299
 
 
 
97b7fcf
dd0962f
5bd8a1a
 
dd0962f
727c299
dd0962f
727c299
 
5bd8a1a
 
727c299
5bd8a1a
7b6b4a2
0a2c880
 
 
 
 
 
 
3b25749
727c299
dd0962f
7b6b4a2
 
 
 
0a2c880
 
 
 
5bf06ec
0a2c880
 
 
 
 
 
 
 
b99a30a
 
 
 
 
0a2c880
8139f95
 
b99a30a
8139f95
b99a30a
5bd8a1a
8139f95
 
dd0962f
b99a30a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import os

import gradio as gr
import requests
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Hugging Face checkpoint shared by the tokenizer and the model below.
model_name = "Writer/palmyra-small"

# Use the GPU when PyTorch can see one; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
model = model.to(device)

_TMDB_API_ROOT = "https://api.themoviedb.org/3"
_TMDB_MOVIE_PAGE_ROOT = "https://www.themoviedb.org/movie"
_TMDB_POSTER_ROOT = "https://image.tmdb.org/t/p/w500"
# NOTE(review): this key is committed to source and therefore public. Rotate it
# and provide the replacement through the TMDB_API_KEY environment variable.
_FALLBACK_TMDB_API_KEY = "20e959f0f28e6b3e3de49c50f358538a"


def _format_movie_summary(movie_id, details):
    """Build the ``(summary, image_url)`` pair from a TMDb movie-details payload."""
    title = details.get("title") or "Unknown Title"
    # The year shown is the first four characters of release_date. Slice only a
    # real date: slicing the placeholder would yield "Unkn", and a null date
    # would raise TypeError, so missing/empty/null all map to "Unknown Year".
    release_date = details.get("release_date") or ""
    year = release_date[:4] or "Unknown Year"
    genre = ", ".join(entry["name"] for entry in details.get("genres") or [])
    tmdb_link = f"{_TMDB_MOVIE_PAGE_ROOT}/{movie_id}"
    poster_path = details.get("poster_path")
    # poster_path is a relative path; prefix it to get a full w500 image URL.
    image_url = f"{_TMDB_POSTER_ROOT}{poster_path}" if poster_path else ""
    summary = (
        f"Title: {title}, Year: {year}, Genre: {genre}\n"
        f"Find more info here: {tmdb_link}"
    )
    return summary, image_url


def get_movie_info(movie_title, api_key=None, timeout=10):
    """Search TMDb for *movie_title* and summarize the first match.

    Args:
        movie_title: Free-text title to search for.
        api_key: TMDb v3 API key. When omitted, the ``TMDB_API_KEY``
            environment variable is used, falling back to the key bundled
            with this script.
        timeout: Seconds to wait on each HTTP request before giving up.

    Returns:
        A ``(summary, image_url)`` tuple. ``summary`` describes the first
        search hit, or is "Movie not found" / "Error: ..." on failure.
        ``image_url`` is the poster URL, or ``""`` when there is no poster
        or no match. Errors are reported in ``summary`` rather than raised,
        because this lookup only enriches the chat reply.
    """
    # Don't spend a request on an empty query; there is nothing to look up.
    if not movie_title or not movie_title.strip():
        return "Movie not found", ""

    if api_key is None:
        api_key = os.environ.get("TMDB_API_KEY", _FALLBACK_TMDB_API_KEY)

    try:
        search_response = requests.get(
            f"{_TMDB_API_ROOT}/search/movie",
            params={
                "api_key": api_key,
                "query": movie_title,
                "language": "en-US",
                "page": 1,
            },
            timeout=timeout,
        )
        # Surface HTTP failures (bad key, rate limit, outage) as errors instead
        # of letting an error payload without "results" read as "not found".
        search_response.raise_for_status()
        results = search_response.json().get("results")
        if not results:
            return "Movie not found", ""

        movie_id = results[0]["id"]
        details_response = requests.get(
            f"{_TMDB_API_ROOT}/movie/{movie_id}",
            params={"api_key": api_key, "language": "en-US"},
            timeout=timeout,
        )
        details_response.raise_for_status()
        return _format_movie_summary(movie_id, details_response.json())
    except Exception as e:  # best-effort enrichment: report, don't crash the chat
        return f"Error: {e}", ""

def generate_response(prompt, max_new_tokens=64):
    """Answer *prompt* with the language model, grounded in TMDb movie info.

    The prompt is used verbatim as a TMDb query via ``get_movie_info``, and the
    result is added to the model's context before the assistant turn.

    Args:
        prompt: The user's chat message.
        max_new_tokens: Upper bound on the number of tokens sampled for the
            model's answer (the prompt itself is not counted).

    Returns:
        A ``(text, image)`` tuple. ``text`` combines the movie info with the
        model's answer. ``image`` is a ``gr.Image`` for the poster, with no
        value when TMDb returned no poster URL.
    """
    movie_info, image_url = get_movie_info(prompt)

    # The movie info is context for the assistant, so it goes before the
    # "ASSISTANT:" cue; anything after the cue reads to the model as the start
    # of its own reply.
    input_text = (
        "A chat between a curious user and an artificial intelligence assistant. "
        "The assistant gives helpful, detailed, and polite answers to the user's questions. "
        f"USER: {prompt} "
        f"Movie Info: {movie_info} "
        "ASSISTANT:"
    )

    model_inputs = tokenizer(input_text, return_tensors="pt").to(device)
    prompt_length = model_inputs["input_ids"].shape[-1]

    gen_conf = {
        "top_k": 20,
        # Bound only the sampled continuation: max_length would count the
        # prompt too, and the prompt alone is longer than 20 tokens, which
        # left no room to generate anything.
        "max_new_tokens": max_new_tokens,
        "temperature": 0.6,
        "do_sample": True,
        "eos_token_id": tokenizer.eos_token_id,
    }

    output = model.generate(**model_inputs, **gen_conf)

    # A causal LM's output repeats the prompt ids first; decode only the new
    # tokens so the reply doesn't echo the whole prompt back to the user.
    generated_text = tokenizer.decode(
        output[0][prompt_length:], skip_special_tokens=True
    ).strip()

    # gr.Image takes `label`, not `alt`. Pass None rather than "" when there is
    # no poster so "" is not treated as an image path.
    # NOTE(review): confirm against the installed Gradio version.
    image_component = gr.Image(image_url or None, label="Movie Poster")

    return f"Movie Info:\n{movie_info}\n\nGenerated Response:\n{generated_text}\n", image_component

# Define chat function for gr.ChatInterface
def chat_function(message, history):
    """Handle one chat turn for ``gr.ChatInterface``.

    Args:
        message: The user's latest message.
        history: Previous turns supplied by Gradio. It is deliberately not
            modified here: ``gr.ChatInterface`` appends the new turn to this
            same list itself after this function returns, so appending here
            too would record every turn twice.

    Returns:
        The ``(text, image)`` pair produced by ``generate_response``.
    """
    # NOTE(review): gr.ChatInterface documents fn as returning a single
    # response; confirm the installed Gradio version renders this tuple.
    response, image_component = generate_response(message)
    return response, image_component

# Wrap chat_function in Gradio's chat UI and start serving it.
chat_interface = gr.ChatInterface(fn=chat_function)
chat_interface.launch()