# NOTE(review): the following lines are residue from a web-scraped file listing
# (Hugging Face Spaces page chrome) and are not part of the program. They are
# commented out so the module parses; they should be deleted entirely.
# Spaces: | Runtime error | File size: 5,174 Bytes | commit 486cd93
import gradio as gr
from sentence_transformers import SentenceTransformer, util
from transformers import pipeline, GPT2Tokenizer
import os
# Define paths and model identifiers for easy reference and maintenance
filename = "output_country_details.txt" # Filename for stored country details
retrieval_model_name = 'output/sentence-transformer-finetuned/'  # Local fine-tuned sentence-transformer checkpoint
gpt2_model_name = "gpt2" # Identifier for the GPT-2 model used
# Tokenizer is loaded eagerly (outside the try below) so generate_response can
# count prompt tokens; loading fails fast at import time if GPT-2 is unavailable.
tokenizer = GPT2Tokenizer.from_pretrained(gpt2_model_name)
# Load models and handle potential failures gracefully
# NOTE(review): if this fails, the error is only printed and execution continues;
# retrieval_model / gpt_model are then undefined and the first call to
# find_relevant_segment / generate_response raises NameError. Consider
# re-raising here so startup fails loudly instead.
try:
    retrieval_model = SentenceTransformer(retrieval_model_name)
    gpt_model = pipeline("text-generation", model=gpt2_model_name)
    print("Models loaded successfully.")
except Exception as e:
    print(f"Failed to load models: {e}")
def load_and_preprocess_text(filename):
    """
    Load text data from a file and preprocess it by stripping whitespace and ignoring empty lines.

    Args:
        filename (str): Path to the file containing text data.
    Returns:
        list of str: Preprocessed lines of text from the file; empty list on read failure.
    """
    try:
        with open(filename, 'r', encoding='utf-8') as file:
            # Keep only non-blank lines, trimmed of surrounding whitespace.
            segments = [line.strip() for line in file if line.strip()]
        print("Text loaded and preprocessed successfully.")
        return segments
    # Catch only expected I/O / decoding failures; a bare Exception would also
    # hide programming errors. Best-effort contract (return []) is preserved.
    except (OSError, UnicodeDecodeError) as e:
        print(f"Failed to load or preprocess text: {e}")
        return []
# Corpus of country-detail segments loaded once at import time; [] if the file is missing.
segments = load_and_preprocess_text(filename)
def find_relevant_segment(user_query, segments):
    """
    Return the stored text segment most similar to the user's query.

    Similarity is cosine similarity between sentence embeddings produced by
    the module-level retrieval model.

    Args:
        user_query (str): User's input query.
        segments (list of str): List of text segments to search from.
    Returns:
        str: The best-matching segment, or "" if embedding/search fails.
    """
    try:
        query_vec = retrieval_model.encode(user_query)
        corpus_vecs = retrieval_model.encode(segments)
        # Cosine similarity of the query against every segment; pick the top hit.
        scores = util.pytorch_cos_sim(query_vec, corpus_vecs)[0]
        best_match = segments[scores.argmax()]
        print("Relevant segment found:", best_match)
        return best_match
    except Exception as e:
        print(f"Error finding relevant segment: {e}")
        return ""
def generate_response(user_query, relevant_segment):
    """
    Generate a response to a user's query using a text generation model based on a relevant text segment.
    Args:
        user_query (str): The user's query.
        relevant_segment (str): The segment of text that is relevant to the query.
    Returns:
        str: A response generated from the model.
    """
    try:
        # The prompt embeds the retrieved fact; note user_query itself is not
        # included in the prompt text.
        prompt = f"Thank you for your question! This is an additional fact about your topic: {relevant_segment}"
        # Allow the model ~50 new tokens beyond the prompt length.
        max_tokens = len(tokenizer(prompt)['input_ids']) + 50
        # NOTE(review): temperature has no effect unless do_sample=True is passed
        # (greedy decoding is the pipeline default) — confirm intent.
        # NOTE(review): max_length is deprecated in newer transformers in favor
        # of max_new_tokens; consider migrating.
        response = gpt_model(prompt, max_length=max_tokens, temperature=0.25)[0]['generated_text']
        # Strip repeated/echoed sentences before returning.
        response_cleaned = clean_up_response(response, relevant_segment)
        return response_cleaned
    except Exception as e:
        print(f"Error generating response: {e}")
        return ""
def clean_up_response(response, segments):
    """
    Clean and format the generated response by removing empty sentences and repetitive parts.

    Drops any sentence that is blank, already appears as a substring of the
    source segment, or duplicates an earlier kept sentence, then re-joins the
    rest and guarantees terminal punctuation.

    Args:
        response (str): The raw response generated by the model.
        segments (str): The text segment used to generate the response.
    Returns:
        str: Cleaned and formatted response.
    """
    kept = []
    for raw in response.split('.'):
        candidate = raw.strip()
        # Skip blanks, sentences echoed from the source segment, and repeats.
        if not candidate or candidate in segments or candidate in kept:
            continue
        kept.append(candidate)
    result = '. '.join(kept).strip()
    if result and not result.endswith((".", "!", "?")):
        result += "."
    return result
# Gradio interface and application logic
def query_model(question):
    """
    Process a question through the model and return the response.
    Args:
        question (str): The question submitted by the user.
    Returns:
        str: Generated response or welcome message if no question is provided.
    """
    if question == "":
        # NOTE(review): welcome_message is not defined anywhere in this file —
        # this raises NameError on empty input, and is a likely cause of the
        # Space's "Runtime error". Define welcome_message at module level.
        return welcome_message
    # Retrieve the best-matching stored segment, then let GPT-2 phrase the reply.
    relevant_segment = find_relevant_segment(question, segments)
    response = generate_response(question, relevant_segment)
    return response
# Gradio UI: markdown header, two info columns, an image, and a Q&A row.
# NOTE(review): welcome_message, topics, and countries are referenced here but
# defined nowhere in this file — gr.Markdown(...) raises NameError at startup,
# which matches the "Runtime error" state. Define them at module level.
with gr.Blocks() as demo:
    gr.Markdown(welcome_message)
    with gr.Row():
        with gr.Column():
            gr.Markdown(topics)
        with gr.Column():
            gr.Markdown(countries)
    with gr.Row():
        # NOTE(review): requires final.png in the working directory — confirm it ships with the app.
        img = gr.Image(os.path.join(os.getcwd(), "final.png"), width=500)
    with gr.Row():
        with gr.Column():
            question = gr.Textbox(label="Your question", placeholder="What do you want to ask about?")
            answer = gr.Textbox(label="VisaBot Response", placeholder="VisaBot will respond here...", interactive=False, lines=10)
            submit_button = gr.Button("Submit")
            submit_button.click(fn=query_model, inputs=question, outputs=answer)
demo.launch()
# (scrape artifact removed — trailing "|" from the file-listing page)