import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from sentence_transformers import SentenceTransformer
import chromadb
import gradio as gr
# Determine the device (CPU or GPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Initialize the tokenizer and model
tokenizer = AutoTokenizer.from_pretrained("./SMOL")
model = AutoModelForCausalLM.from_pretrained("./SMOL").to(device)
# Initialize the sentence transformer model
smodel = SentenceTransformer('./embed')
# Initialize the chromadb client and collection
client = chromadb.PersistentClient(path="vectordb")
collection = client.get_or_create_collection("dogdb")
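
# NOTE: this app only reads from "dogdb"; the collection is assumed to have
# been populated offline. A minimal ingest sketch, with a hypothetical source
# file and id scheme (not part of this app):
#   docs = open("dog_breeds.txt").read().split("\n\n")
#   collection.add(
#       ids=[str(i) for i in range(len(docs))],    # hypothetical ids
#       documents=docs,
#       embeddings=smodel.encode(docs).tolist(),   # same embedder used for queries
#   )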
def clean_text_block(text):
    """Extract the retrieved document snippets from the string form of a chromadb query result."""
    start_keyword = "'documents': [["
    end_keyword = "]], 'uris':"
    start_index = text.find(start_keyword)
    end_index = text.find(end_keyword)
    if start_index != -1 and end_index != -1:
        # Keep only the text between the two markers (the retrieved documents)
        return text[start_index + len(start_keyword):end_index]
    else:
        return "Keywords not found in the text."
# The persona line appears both in the prompt and in the template cleanup below,
# so define it once to keep the two in sync.
PERSONA = "Respond in a friendly manner; you are an informational assistant about dogs."

def remove_unwanted_parts(text):
    """Strip the echoed chat template (system prompt through the assistant tag) from the decoded output."""
    start_keyword = "system"
    end_keyword = PERSONA + "\nassistant"
    start_idx = text.find(start_keyword)
    end_idx = text.find(end_keyword)
    if start_idx != -1 and end_idx != -1:
        # Drop everything from the start of the template to the end of the assistant tag
        return (text[:start_idx] + text[end_idx + len(end_keyword):]).strip()
    else:
        return text
def generate_response(question):
    # Embed the question with the same model used to build the collection,
    # then retrieve the three most relevant passages from the vector store.
    query_embeddings = smodel.encode([f"{question}?"])
    results = collection.query(
        query_embeddings=query_embeddings.tolist(),
        n_results=3,  # how many passages to retrieve
    )
    # Pull the document snippets out of the result's string form
    context = clean_text_block(str(results))
    messages = [{"role": "user", "content": f"""After the colon is a set of text with information about dogs, then a question about the given text. Please answer the question based on the text, and do not talk about the documentation:
text - {context}
question - {question}
{PERSONA}"""}]
    input_text = tokenizer.apply_chat_template(messages, tokenize=False)
    encoded_inputs = tokenizer.encode_plus(input_text, return_tensors="pt", add_special_tokens=True)
    inputs = encoded_inputs["input_ids"].to(device)
    attention_mask = encoded_inputs["attention_mask"].to(device)
    outputs = model.generate(
        inputs,
        attention_mask=attention_mask,
        max_new_tokens=150,
        temperature=0.4,
        top_p=0.6,
        do_sample=True,
    )
    # The decoded sequence echoes the prompt, so strip the chat template
    # before returning the answer.
    output_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return remove_unwanted_parts(output_text)
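
# Quick sanity check for local runs (commented out so the Space boots straight
# into the Gradio app); the sample question is just an illustration:
#   print(generate_response("How much exercise does a Border Collie need"))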
# Create Gradio interface
iface = gr.Interface(
    fn=generate_response,
    inputs="text",
    outputs="text",
    title="Dog Breed Q&A",
    description="Ask a question about your dog's breed! Covers over 70 different breeds; "
                "the full list is in this Space's files under 'Dog_List'. All done on a CPU!",
)
# Launch the interface
iface.launch()