# Source: llm-chatbot / app.py
# Committed by lightmate: "Update app.py" (9f09252, verified), 8.58 kB
# app.py
import os
import re
from pathlib import Path
from threading import Event, Thread
from typing import List, Tuple
# Importing necessary packages
import torch
from transformers import AutoConfig, AutoTokenizer, StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer
from langchain_community.tools import DuckDuckGoSearchRun
from optimum.intel.openvino import OVModelForCausalLM
import openvino as ov
import openvino.properties as props
import openvino.properties.hint as hints
import openvino.properties.streams as streams
from gradio_helper import make_demo # UI logic import
from llm_config import SUPPORTED_LLM_MODELS
# Model configuration setup.
# Generation / deployment knobs. prepare_int4_model_value, enable_awq_value and
# model_to_run_value are not referenced elsewhere in this file as shown.
max_new_tokens = 256
model_language_value = "English"
model_id_value = 'qwen2.5-0.5b-instruct'
prepare_int4_model_value = True
enable_awq_value = False
device_value = 'CPU'
model_to_run_value = 'INT4'

# Per-model settings looked up from llm_config's registry.
model_configuration = SUPPORTED_LLM_MODELS[model_language_value][model_id_value]
pt_model_id = model_name = model_configuration["model_id"]
# Model family prefix, e.g. "qwen2.5" for "qwen2.5-0.5b-instruct".
pt_model_name = model_id_value.partition("-")[0]

# Location of the INT4-compressed OpenVINO export.
int4_model_dir = Path(model_id_value, "INT4_compressed_weights")
int4_weights = int4_model_dir.joinpath("openvino_model.bin")

# Prompt-format settings. A model without a legacy string template defaults to
# using the tokenizer's chat template.
start_message = model_configuration["start_message"]
history_template = model_configuration.get("history_template")
has_chat_template = model_configuration.get("has_chat_template", history_template is None)
current_message_template = model_configuration.get("current_message_template")
stop_tokens = model_configuration.get("stop_tokens")
tokenizer_kwargs = model_configuration.get("tokenizer_kwargs", {})
# Model loading.
core = ov.Core()

# OpenVINO runtime properties: latency performance hint and a single inference
# stream. cache_dir is an empty string (presumably leaving compiled-model
# caching disabled; confirm against OpenVINO docs).
ov_config = {}
ov_config[hints.performance_mode()] = hints.PerformanceMode.LATENCY
ov_config[streams.num()] = "1"
ov_config[props.cache_dir()] = ""

# Tokenizer and model are both read from the local INT4 export directory.
tok = AutoTokenizer.from_pretrained(int4_model_dir, trust_remote_code=True)
ov_model = OVModelForCausalLM.from_pretrained(
    int4_model_dir,
    config=AutoConfig.from_pretrained(int4_model_dir, trust_remote_code=True),
    device=device_value,
    ov_config=ov_config,
    trust_remote_code=True,
)
# Stopping criterion: end generation when the last generated token is a stop token.
class StopOnTokens(StoppingCriteria):
    """Signal a stop once the newest token of the first sequence is in ``token_ids``.

    Only ``input_ids[0]`` (the first sequence of the batch) is inspected.
    """

    def __init__(self, token_ids):
        # Token ids that terminate generation, compared one at a time.
        self.token_ids = token_ids

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        for stop_id in self.token_ids:
            if input_ids[0][-1] == stop_id:
                return True
        return False
# Normalise configured stop tokens into a stopping-criteria list.
# Entries may be token strings (converted to ids with the tokenizer) or ids already.
# Each entry is converted individually, so mixed lists work. An empty list is left
# as-is instead of crashing on stop_tokens[0].
if stop_tokens:
    stop_tokens = [
        tok.convert_tokens_to_ids(token) if isinstance(token, str) else token
        for token in stop_tokens
    ]
    stop_tokens = [StopOnTokens(stop_tokens)]
# Default way of accumulating streamed text: plain concatenation.
def default_partial_text_processor(partial_text: str, new_text: str) -> str:
    """Return the text streamed so far with the newest chunk appended."""
    return "".join((partial_text, new_text))
# A model config may supply its own "partial_text_processor" for post-processing
# streamed chunks; otherwise chunks are simply appended (see bot's streaming loop).
text_processor = model_configuration.get("partial_text_processor", default_partial_text_processor)
# Convert conversation history to tokens based on the model's prompt format.
def convert_history_to_token(history: List[Tuple[str, str]]):
    """Encode the conversation into model input ids using the model's prompt format.

    The format is selected from module-level configuration:

    * ``pt_model_name == "baichuan2"``: a hand-built id sequence. Id 195 precedes
      each user message and id 196 precedes each assistant reply (presumably
      Baichuan2's role-marker ids; confirm against its generation config).
    * chat-template models (``history_template is None`` or ``has_chat_template``):
      role/content messages rendered by ``tok.apply_chat_template`` with a
      generation prompt appended.
    * otherwise: ``history_template`` is formatted for each completed round and
      ``current_message_template`` for the final one, then the text is tokenized.

    Args:
        history: ``(user, assistant)`` pairs in conversation order. The last pair
            is the turn being answered.

    Returns:
        A tensor of input ids with a leading batch dimension of 1.
    """
    if pt_model_name == "baichuan2":
        user_token_id, assistant_token_id = 195, 196
        input_tokens = list(tok.encode(start_message))
        # Completed rounds follow the system prompt in chronological order.
        # The old code prepended each round while iterating forward, which
        # reversed the history.
        for old_query, response in history[:-1]:
            input_tokens += [user_token_id] + tok.encode(old_query) + [assistant_token_id] + tok.encode(response)
        input_tokens += [user_token_id] + tok.encode(history[-1][0]) + [assistant_token_id]
        input_token = torch.LongTensor([input_tokens])
    elif history_template is None or has_chat_template:
        messages = [{"role": "system", "content": start_message}]
        for idx, (user_msg, model_msg) in enumerate(history):
            # The final turn normally has no reply yet; add only the user message.
            if idx == len(history) - 1 and not model_msg:
                messages.append({"role": "user", "content": user_msg})
                break
            if user_msg:
                messages.append({"role": "user", "content": user_msg})
            if model_msg:
                messages.append({"role": "assistant", "content": model_msg})
        input_token = tok.apply_chat_template(messages, add_generation_prompt=True, tokenize=True, return_tensors="pt")
    else:
        # NOTE(review): completed rounds are numbered from 0 (enumerate), but the
        # current round gets len(history) + 1, which skips numbers. Confirm which
        # base the configured templates expect.
        past_rounds = "".join(
            history_template.format(num=round_idx, user=item[0], assistant=item[1])
            for round_idx, item in enumerate(history[:-1])
        )
        current_round = current_message_template.format(
            num=len(history) + 1, user=history[-1][0], assistant=history[-1][1]
        )
        text = start_message + past_rounds + current_round
        input_token = tok(text, return_tensors="pt", **tokenizer_kwargs).input_ids
    return input_token
# Web search tool (LangChain DuckDuckGo wrapper). fetch_search_results calls its
# .invoke() for queries that should_use_search flags.
search = DuckDuckGoSearchRun()
# Keywords whose presence suggests a query needs fresh information from the web.
# Duplicates from the original list were removed.
_SEARCH_KEYWORDS = (
    "latest", "news", "update", "which", "who", "what", "when", "why", "how", "recent", "current",
    "announcement", "bulletin", "report", "brief", "insight", "disclosure",
    "release", "memo", "headline", "ongoing", "fresh", "upcoming", "immediate",
    "recently", "new", "now", "in-progress", "inquiry", "query", "ask", "investigate",
    "explore", "seek", "clarify", "confirm", "discover", "learn", "describe", "define",
    "illustrate", "outline", "interpret", "expound", "detail", "summarize", "elucidate",
    "break down", "outcome", "effect", "consequence", "finding", "achievement", "conclusion",
    "product", "performance", "resolution",
)
# Keywords must start a word. Inflections still match ("updates", "newest"),
# but words that merely contain a keyword do not ("know" / "now", "show" / "how",
# "task" / "ask"). Compiled once at import time.
_SEARCH_PATTERN = re.compile(r"\b(?:" + "|".join(map(re.escape, _SEARCH_KEYWORDS)) + r")")


def should_use_search(query: str) -> bool:
    """Return True when *query* contains a search-trigger keyword at a word start.

    Matching is case-insensitive. An empty (or ``None``) query never triggers a search.
    """
    return bool(query) and _SEARCH_PATTERN.search(query.lower()) is not None
# Build the plain-text prompt used when web-search context is available.
def construct_model_prompt(user_query: str, search_context: str, history: List[Tuple[str, str]]) -> str:
    """Assemble instructions, optional search context and the query into one prompt.

    The three sections are separated by blank lines, and the prompt ends with a
    blank line. A falsy ``search_context`` leaves its section empty. ``history``
    is accepted but not used.
    """
    instructions = (
        "Use the information below if relevant to provide an accurate and concise answer. "
        "If no information is available, rely on your general knowledge."
    )
    context_section = f"{search_context}" if search_context else ""
    return "\n\n".join((instructions, context_section, f"{user_query} ?")) + "\n\n"
# Fetch search results for a query.
def fetch_search_results(query: str) -> str:
    """Run a web search for *query* and format the results as prompt context.

    Returns:
        ``"Relevant and recent information:\\n<results>"`` on success. Returns
        ``""`` when the search raises or returns nothing. An empty string is
        falsy, so ``bot`` treats it as "no search context" and answers from the
        chat history instead of failing the whole reply.
    """
    try:
        search_results = search.invoke(query)
    except Exception as err:  # external service boundary: network / rate-limit errors
        print("Search failed, continuing without web context:", err)
        return ""
    print("Search results:", search_results)  # Optional: Debugging output
    if not search_results:
        return ""
    return f"Relevant and recent information:\n{search_results}"
# Main chatbot function.
def bot(history, temperature, top_p, top_k, repetition_penalty, conversation_id):
    """Stream a reply to the latest user message, optionally grounded in a web search.

    Args:
        history: chat history as ``(user, assistant)`` pairs. The last pair holds
            the current user message; its reply is filled in as it streams.
        temperature: sampling temperature; ``0`` disables sampling.
        top_p, top_k, repetition_penalty: forwarded to ``ov_model.generate``.
        conversation_id: not used here (presumably required by the UI callback
            signature).

    Yields:
        The chat history with the last turn's reply as streamed so far.

    Raises:
        Exception: a non-cancellation error from ``ov_model.generate`` in the
            worker thread is re-raised here after streaming stops.
    """
    user_query = history[-1][0]
    search_context = fetch_search_results(user_query) if should_use_search(user_query) else ""
    if search_context:
        # Search path: plain-text prompt of instructions + search results + query
        # (prior turns are not included), capped at 2500 tokens.
        # NOTE(review): truncation trims the end of the prompt by default, which
        # could cut off the user's question if the context is very long. Confirm.
        prompt = construct_model_prompt(user_query, search_context, history)
        input_ids = tok(prompt, return_tensors="pt", truncation=True, max_length=2500).input_ids
    else:
        input_ids = convert_history_to_token(history)

    # Limit input length to avoid exceeding the token budget. Drop earlier turns and
    # re-encode; previously history was truncated but input_ids were never rebuilt.
    # The search prompt does not depend on history, so it is not re-encoded.
    if input_ids.shape[1] > 2000:
        history = [history[-1]]
        if not search_context:
            input_ids = convert_history_to_token(history)

    # Configure response streaming.
    streamer = TextIteratorStreamer(tok, timeout=4600.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = {
        "input_ids": input_ids,
        "max_new_tokens": max_new_tokens,
        "temperature": temperature,
        "do_sample": temperature > 0.0,
        "top_p": top_p,
        "top_k": top_k,
        "repetition_penalty": repetition_penalty,
        "streamer": streamer,
        "stopping_criteria": StoppingCriteriaList(stop_tokens) if stop_tokens is not None else None,
    }

    stream_complete = Event()
    generation_errors = []

    def generate_and_signal_complete():
        """Run generation in the worker thread and always terminate the stream."""
        try:
            ov_model.generate(**generate_kwargs)
        except Exception as e:
            if isinstance(e, RuntimeError) and "Infer Request was canceled" in str(e):
                print("Generation request was canceled.")
            else:
                # Hand the error to the consumer, which re-raises it after the loop.
                generation_errors.append(e)
            # generate() did not get to close the stream. Close it here (after the
            # error is recorded) so the loop below ends instead of blocking until
            # the streamer timeout.
            streamer.end()
        finally:
            stream_complete.set()

    t1 = Thread(target=generate_and_signal_complete)
    t1.start()

    partial_text = ""
    for new_text in streamer:
        partial_text = text_processor(partial_text, new_text)
        history[-1] = (user_query, partial_text)
        yield history

    if generation_errors:
        raise generation_errors[0]
def request_cancel():
    """Cancel the OpenVINO inference request held by ``ov_model``.

    A cancelled request makes ``generate`` raise a RuntimeError mentioning
    "Infer Request was canceled", which bot's worker thread treats as a
    cancellation rather than a failure.
    """
    active_request = ov_model.request
    active_request.cancel()
# Gradio setup and launch.
# NOTE(review): request_cancel is defined above but not passed to make_demo. If
# gradio_helper.make_demo accepts a stop callback, wiring it in would let the UI
# abort generation. Confirm its signature.
demo = make_demo(run_fn=bot, title="OpenVINO Search & Reasoning Chatbot", language=model_language_value)

if __name__ == "__main__":
    # server_name="0.0.0.0" binds all network interfaces. share=True asks Gradio
    # for a public share link (per Gradio's launch() documentation).
    demo.launch(debug=True, share=True, server_name="0.0.0.0", server_port=7860)