"""app.py for the chat-with-SFF Space: a Gradio retrieval-augmented chatbot
that answers questions about additive manufacturing research using excerpts
retrieved from a vectorstore of SFF publications."""

import threading  # Background generation thread for streaming

import gradio  # Interface handling
import spaces  # For GPU allocation
import langchain_community.vectorstores  # Vectorstore for publications
import langchain_huggingface  # Embeddings
import transformers  # Tokenizer, model, and streaming utilities

# The number of publications to retrieve for the prompt
PUBLICATIONS_TO_RETRIEVE = 5
# The template for the RAG prompt
RAG_TEMPLATE = """You are an AI assistant who enjoys helping users learn about research.
Answer the USER_QUERY on additive manufacturing research using the RESEARCH_EXCERPTS.
Provide a concise ANSWER based on these excerpts. Avoid listing references.
===== RESEARCH_EXCERPTS =====
{research_excerpts}
===== USER_QUERY =====
{query}
===== ANSWER =====
"""
# Load vectorstore of SFF publications
publication_vectorstore = langchain_community.vectorstores.FAISS.load_local(
    folder_path="publication_vectorstore",
    embeddings=langchain_huggingface.HuggingFaceEmbeddings(
        model_name="all-MiniLM-L12-v2",
        model_kwargs={"device": "cuda"},
        encode_kwargs={"normalize_embeddings": False},
    ),
    # The index is a trusted local artifact; loading it unpickles its docstore
    allow_dangerous_deserialization=True,
)
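
# Note: load_local expects the index.faiss and index.pkl files written by
# FAISS.save_local to be present inside folder_path.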

# Load the tokenizer and model once at startup so each request does not
# reload the 7B checkpoint; generation itself runs under @spaces.GPU below
tokenizer = transformers.AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct-AWQ")
model = transformers.AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2.5-7B-Instruct-AWQ",
    torch_dtype="auto",
    device_map="cuda",
)


def preprocess(query: str) -> str:
    """
    Generates a prompt based on the top k documents matching the query.

    Args:
        query (str): The user's query.

    Returns:
        str: The formatted prompt containing research excerpts and the user's query.
    """
    # Search for the top k documents matching the query
    documents = publication_vectorstore.search(
        query, k=PUBLICATIONS_TO_RETRIEVE, search_type="similarity"
    )

    # Extract the page content from the documents
    research_excerpts = [f'"... {doc.page_content}..."' for doc in documents]

    # Format the prompt with the research excerpts and the user's query
    prompt = RAG_TEMPLATE.format(
        research_excerpts="\n\n".join(research_excerpts), query=query
    )

    return prompt
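
# Example (illustrative query): preprocess("What is binder jetting?") returns
# RAG_TEMPLATE with the five most similar publication excerpts substituted
# for {research_excerpts} and the question itself for {query}.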


@spaces.GPU
def reply(message: str, history: list[str]):
    """
    Streams a response to the user's message.

    Args:
        message (str): The user's message or query.
        history (list[str]): The conversation history (unused; each query is
            answered from the retrieved excerpts alone).

    Yields:
        str: The response generated so far.
    """
    # Build the RAG prompt and move the tokenized inputs to the model's device
    inputs = tokenizer([preprocess(message)], return_tensors="pt").to(model.device)

    # Stream decoded tokens as they are produced, without echoing the prompt
    # or emitting special tokens
    streamer = transformers.TextIteratorStreamer(
        tokenizer, skip_prompt=True, skip_special_tokens=True
    )
    generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=512)

    # Run generation in a background thread so partial output can be yielded
    thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    generated_text = ""
    for new_text in streamer:
        generated_text += new_text
        yield generated_text
    thread.join()


# Example Queries for Interface
EXAMPLE_QUERIES = [
    {"text": "What is multi-material 3D printing?"},
    {"text": "How is additive manufacturing being applied in aerospace?"},
    {"text": "Tell me about innovations in metal 3D printing techniques."},
    {"text": "What are some sustainable materials for 3D printing?"},
    {"text": "What are the biggest challenges with support structures in additive manufacturing?"},
    {"text": "How is 3D printing impacting the medical field?"},
    {"text": "What are some common applications of additive manufacturing in industry?"},
    {"text": "What are the benefits and limitations of using polymers in 3D printing?"},
    {"text": "Tell me about the environmental impacts of additive manufacturing."},
    {"text": "What are the primary limitations of current 3D printing technologies?"},
    {"text": "How are researchers improving the speed of 3D printing processes?"},
    {"text": "What are the best practices for managing post-processing in additive manufacturing?"},
]

# Run the Gradio Interface
gradio.ChatInterface(
    reply,
    examples=EXAMPLE_QUERIES,
    cache_examples=False,
    chatbot=gradio.Chatbot(
        show_label=False,
        show_share_button=False,
        show_copy_button=False,
        bubble_full_width=False,
    ),
).launch(debug=True)