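"""Korean drama recommendation chatbot.

Loads drama metadata from a CSV file, embeds each synopsis with a
SentenceTransformer model, and serves a Gradio interface that returns the
three dramas whose synopses best match the user's query.
"""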
import pandas as pd
import gradio as gr
import torch
from sentence_transformers import SentenceTransformer, util

def load_data(file_path):
    """Load and preprocess the Korean drama data."""
    df = pd.read_csv(file_path)
    # Missing synopses become empty strings rather than the literal string 'nan'.
    df['synopsis'] = df['synopsis'].fillna('').astype(str)
    return df

def create_embeddings(model, texts):
    """Create embeddings for the given texts using the provided model."""
    return model.encode(texts, convert_to_tensor=True)

def chatbot(query, model, drama_embeddings, korean_drama_df):
    """Chatbot function to find the most relevant dramas based on the query."""
    query_embedding = model.encode(query, convert_to_tensor=True)
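    # Rank every precomputed synopsis embedding against the query by cosine similarity.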
    cosine_scores = util.pytorch_cos_sim(query_embedding, drama_embeddings)[0]
    
    # Find the top 3 most similar dramas (fewer if the dataset has fewer rows)
    top_results = torch.topk(cosine_scores, k=min(3, len(korean_drama_df)))
    
    response = ""
    for score, idx in zip(top_results.values, top_results.indices):
        row = korean_drama_df.iloc[idx.item()]
        response += f"**{row['drama_name']}**\n"
        response += f"Year: {row['year']}\n"
        response += f"Director: {row['director']}\n"
        response += f"Screenwriter: {row['screenwriter']}\n"
        response += f"Country: {row['country']}\n"
        response += f"Type: {row['type']}\n"
        response += f"Total Episodes: {row['tot_eps']}\n"
        response += f"Duration: {row['duration']} minutes\n"
        response += f"Start Date: {row['start_dt']}\n"
        response += f"End Date: {row['end_dt']}\n"
        response += f"Aired On: {row['aired_on']}\n"
        response += f"Original Network: {row['org_net']}\n"
        response += f"Content Rating: {row['content_rt']}\n"
        response += f"Synopsis: {row['synopsis']}\n"
        response += f"Rank: {row['rank']}\n"
        response += f"Popularity: {row['pop']}\n"
        response += "\n\n"
    
    return response

def main():
    # Load data
    data_file = 'data/korean_drama.csv'
    korean_drama_df = load_data(data_file)

    # Load the pre-trained SentenceTransformer model
    model = SentenceTransformer('all-MiniLM-L6-v2')

    # Create embeddings for the drama synopses
    drama_synopses = korean_drama_df['synopsis'].tolist()
    drama_embeddings = create_embeddings(model, drama_synopses)

    # Define the Gradio interface
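    # The lambda closes over the model, embeddings, and dataframe created above,
    # so the function Gradio calls only takes the user's query as input.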
    iface = gr.Interface(
        fn=lambda query: chatbot(query, model, drama_embeddings, korean_drama_df),
        inputs="text",
        outputs="text",
        title="Korean Drama Chatbot",
        description="Ask me about Korean dramas!"
    )

    # Launch the interface
    iface.launch()

if __name__ == "__main__":
    main()