File size: 5,174 Bytes
486cd93
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
import gradio as gr
from sentence_transformers import SentenceTransformer, util
from transformers import pipeline, GPT2Tokenizer
import os

# Define paths and model identifiers for easy reference and maintenance
filename = "output_country_details.txt"  # Filename for stored country details
retrieval_model_name = 'output/sentence-transformer-finetuned/'
gpt2_model_name = "gpt2"  # Identifier for the GPT-2 model used
tokenizer = GPT2Tokenizer.from_pretrained(gpt2_model_name)

# Load models; fail fast on error. Swallowing the exception here would leave
# retrieval_model / gpt_model undefined and every later call would crash with
# a confusing NameError instead of the real cause.
try:
    retrieval_model = SentenceTransformer(retrieval_model_name)
    gpt_model = pipeline("text-generation", model=gpt2_model_name)
    print("Models loaded successfully.")
except Exception as e:
    print(f"Failed to load models: {e}")
    raise

def load_and_preprocess_text(filename):
    """
    Load text data from a file and preprocess it by stripping whitespace
    and ignoring empty lines.

    Args:
    filename (str): Path to the file containing text data.

    Returns:
    list of str: Preprocessed, non-empty lines of text from the file;
    an empty list if the file cannot be read.
    """
    try:
        # Narrow except: only I/O failures (missing file, permissions, ...)
        # are expected here; programming errors should not be silenced.
        with open(filename, 'r', encoding='utf-8') as file:
            segments = [line.strip() for line in file if line.strip()]
    except OSError as e:
        print(f"Failed to load or preprocess text: {e}")
        return []
    print("Text loaded and preprocessed successfully.")
    return segments

# Build the retrieval corpus once at import time; [] if the file is unreadable.
segments = load_and_preprocess_text(filename)

def find_relevant_segment(user_query, segments):
    """
    Pick the segment whose sentence embedding is most similar to the query.

    Args:
    user_query (str): User's input query.
    segments (list of str): Candidate text segments to search.

    Returns:
    str: Best-matching segment, or "" if the search fails.
    """
    try:
        # Embed the query and every candidate, then rank by cosine similarity.
        query_vec = retrieval_model.encode(user_query)
        segment_vecs = retrieval_model.encode(segments)
        scores = util.pytorch_cos_sim(query_vec, segment_vecs)[0]
        winner = segments[scores.argmax()]
        print("Relevant segment found:", winner)
        return winner
    except Exception as e:
        print(f"Error finding relevant segment: {e}")
        return ""

def generate_response(user_query, relevant_segment):
    """
    Produce a reply by prompting the GPT-2 pipeline with the relevant segment.

    Args:
    user_query (str): The user's query.
    relevant_segment (str): Text segment relevant to the query.

    Returns:
    str: Cleaned generated text, or "" if generation fails.
    """
    try:
        prompt = f"Thank you for your question! This is an additional fact about your topic: {relevant_segment}"
        # Budget: the prompt's own token count plus ~50 newly generated tokens.
        prompt_tokens = len(tokenizer(prompt)['input_ids'])
        outputs = gpt_model(prompt, max_length=prompt_tokens + 50, temperature=0.25)
        raw_text = outputs[0]['generated_text']
        return clean_up_response(raw_text, relevant_segment)
    except Exception as e:
        print(f"Error generating response: {e}")
        return ""

def clean_up_response(response, segments):
    """
    Tidy generated text by dropping blank, duplicate, or source-copied sentences.

    Args:
    response (str): The raw response generated by the model.
    segments (str): The text segment used to generate the response.

    Returns:
    str: Cleaned response, guaranteed to end with a sentence terminator.
    """
    kept = []
    for raw in response.split('.'):
        sentence = raw.strip()
        # Skip empties, sentences lifted verbatim from the source segment,
        # and anything we have already kept.
        if not sentence or sentence in segments or sentence in kept:
            continue
        kept.append(sentence)
    result = '. '.join(kept).strip()
    if result and not result.endswith((".", "!", "?")):
        result += "."
    return result

# Gradio interface and application logic
def query_model(question):
    """
    Route a user question through retrieval and generation.

    Args:
    question (str): The question submitted by the user.

    Returns:
    str: Generated response, or the welcome message if no question is provided.
    """
    # NOTE(review): welcome_message is not defined in this file — confirm it
    # is provided elsewhere before launch.
    if question == "":
        return welcome_message
    segment = find_relevant_segment(question, segments)
    return generate_response(question, segment)

# Gradio UI: welcome header, two markdown panels, an image, and a Q&A form.
# NOTE(review): welcome_message, topics, and countries are referenced below but
# are not defined anywhere in this file — confirm they exist at module level,
# otherwise building this interface raises NameError.
with gr.Blocks() as demo:
    gr.Markdown(welcome_message)
    with gr.Row():
        with gr.Column():
            gr.Markdown(topics)
        with gr.Column():
            gr.Markdown(countries)
    with gr.Row():
        # Image is loaded from the current working directory, not the script's
        # directory — launching from elsewhere will not find "final.png".
        img = gr.Image(os.path.join(os.getcwd(), "final.png"), width=500)
    with gr.Row():
        with gr.Column():
            question = gr.Textbox(label="Your question", placeholder="What do you want to ask about?")
            answer = gr.Textbox(label="VisaBot Response", placeholder="VisaBot will respond here...", interactive=False, lines=10)
            submit_button = gr.Button("Submit")
            submit_button.click(fn=query_model, inputs=question, outputs=answer)

demo.launch()