"""VisaBot: answer questions by retrieving a stored country fact and extending it with GPT-2."""

import os

import gradio as gr
from sentence_transformers import SentenceTransformer, util
from transformers import pipeline, GPT2Tokenizer

# Define paths and model identifiers for easy reference and maintenance
filename = "output_country_details.txt"  # Filename for stored country details
retrieval_model_name = 'output/sentence-transformer-finetuned/'  # Local sentence-transformer checkpoint
gpt2_model_name = "gpt2"  # Identifier for the GPT-2 model used

# NOTE(review): these three Markdown strings are rendered by the UI below (and
# welcome_message is returned for empty questions), but they were not defined
# anywhere in this file, so startup failed with a NameError. The text here is a
# neutral placeholder -- replace it with the intended copy.
welcome_message = (
    "# Welcome to VisaBot!\n\n"
    "Ask a question below and VisaBot will answer using its stored country details."
)
topics = (
    "## Topics\n\n"
    "Ask about any topic covered in the stored country details."
)
countries = (
    "## Countries\n\n"
    "Ask about any country included in the stored country details."
)

# Start from None so a failed load leaves well-defined names that the helpers
# below can check, instead of raising NameError on first use.
tokenizer = None
retrieval_model = None
gpt_model = None

# Load models and handle potential failures gracefully
try:
    tokenizer = GPT2Tokenizer.from_pretrained(gpt2_model_name)
    retrieval_model = SentenceTransformer(retrieval_model_name)
    # Hand the already-loaded tokenizer to the pipeline instead of loading a second copy.
    gpt_model = pipeline("text-generation", model=gpt2_model_name, tokenizer=tokenizer)
    print("Models loaded successfully.")
except Exception as e:
    print(f"Failed to load models: {e}")


def load_and_preprocess_text(filename):
    """
    Load text data from a file and preprocess it by stripping whitespace and ignoring empty lines.

    Args:
        filename (str): Path to the file containing text data.

    Returns:
        list of str: Non-empty, stripped lines of the file, or an empty list if
        the file cannot be opened or decoded as UTF-8.
    """
    try:
        with open(filename, 'r', encoding='utf-8') as file:
            segments = [line.strip() for line in file if line.strip()]
    except (OSError, UnicodeDecodeError) as e:
        print(f"Failed to load or preprocess text: {e}")
        return []
    print("Text loaded and preprocessed successfully.")
    return segments


segments = load_and_preprocess_text(filename)

# Single-entry cache of segment embeddings. The corpus does not change after
# startup, so re-encoding every segment on each query is wasted work.
_segment_embedding_cache = {"segments": None, "embeddings": None}


def _encode_segments(segments):
    """Return embeddings for *segments*, re-encoding only when the list content changes."""
    key = tuple(segments)
    if _segment_embedding_cache["segments"] != key:
        # Store the embeddings before the key so a failed encode never leaves a stale key.
        _segment_embedding_cache["embeddings"] = retrieval_model.encode(segments)
        _segment_embedding_cache["segments"] = key
    return _segment_embedding_cache["embeddings"]


def find_relevant_segment(user_query, segments):
    """
    Identify the most relevant text segment from a list based on a user's query using sentence embeddings.

    Args:
        user_query (str): User's input query.
        segments (list of str): List of text segments to search from.

    Returns:
        str: The segment with the highest cosine similarity to the query, or ""
        if there are no segments, the retrieval model is not loaded, or encoding fails.
    """
    if not segments:
        print("Error finding relevant segment: no text segments are loaded.")
        return ""
    if retrieval_model is None:
        print("Error finding relevant segment: retrieval model is not loaded.")
        return ""
    try:
        query_embedding = retrieval_model.encode(user_query)
        segment_embeddings = _encode_segments(segments)
        similarities = util.pytorch_cos_sim(query_embedding, segment_embeddings)[0]
        # argmax() yields a 0-d tensor; convert it to a plain int before indexing a list.
        best_idx = int(similarities.argmax())
    except Exception as e:
        print(f"Error finding relevant segment: {e}")
        return ""
    best_segment = segments[best_idx]
    print("Relevant segment found:", best_segment)
    return best_segment


def generate_response(user_query, relevant_segment):
    """
    Generate a response to a user's query using a text generation model based on a relevant text segment.

    Args:
        user_query (str): The user's query.
        relevant_segment (str): The segment of text that is relevant to the query.

    Returns:
        str: The cleaned generated text, or "" if the model is not loaded or generation fails.
    """
    if gpt_model is None:
        print("Error generating response: text-generation model is not loaded.")
        return ""
    # NOTE(review): user_query is not part of the prompt; the model only
    # continues the retrieved fact. Confirm this is intended.
    prompt = f"Thank you for your question! This is an additional fact about your topic: {relevant_segment}"
    try:
        outputs = gpt_model(
            prompt,
            max_new_tokens=50,  # same budget as the former "prompt length + 50" max_length
            do_sample=True,  # temperature only takes effect when sampling is enabled
            temperature=0.25,
        )
    except Exception as e:
        print(f"Error generating response: {e}")
        return ""
    return clean_up_response(outputs[0]['generated_text'], relevant_segment)


def clean_up_response(response, segments):
    """
    Clean and format the generated response by removing empty sentences and repetitive parts.

    Args:
        response (str): The raw response generated by the model.
        segments (str): The text segment used to generate the response. Any
            sentence that appears as a substring of it is dropped.

    Returns:
        str: The remaining unique sentences joined with ". ", ending in terminal punctuation.
    """
    cleaned_sentences = []
    seen = set()
    for sentence in (part.strip() for part in response.split('.')):
        # Skip empty fragments, text already contained in the source segment, and repeats.
        if not sentence or sentence in segments or sentence in seen:
            continue
        seen.add(sentence)
        cleaned_sentences.append(sentence)
    cleaned_response = '. '.join(cleaned_sentences).strip()
    if cleaned_response and not cleaned_response.endswith((".", "!", "?")):
        cleaned_response += "."
    return cleaned_response


# Gradio interface and application logic
def query_model(question):
    """
    Process a question through the model and return the response.

    Args:
        question (str): The question submitted by the user.

    Returns:
        str: Generated response, or the welcome message if the question is empty,
        None, or whitespace-only.
    """
    if not question or not question.strip():
        return welcome_message
    relevant_segment = find_relevant_segment(question, segments)
    response = generate_response(question, relevant_segment)
    return response


with gr.Blocks() as demo:
    gr.Markdown(welcome_message)
    with gr.Row():
        with gr.Column():
            gr.Markdown(topics)
        with gr.Column():
            gr.Markdown(countries)
    with gr.Row():
        img = gr.Image(os.path.join(os.getcwd(), "final.png"), width=500)
    with gr.Row():
        with gr.Column():
            question = gr.Textbox(label="Your question", placeholder="What do you want to ask about?")
            answer = gr.Textbox(label="VisaBot Response", placeholder="VisaBot will respond here...", interactive=False, lines=10)
            submit_button = gr.Button("Submit")
            submit_button.click(fn=query_model, inputs=question, outputs=answer)

if __name__ == "__main__":
    demo.launch()