AlexanderKazakov
small improvement for chunking; openai embeddings
eeafaaa
raw
history blame
6.92 kB
import gradio as gr
from typing import Any, Dict, Generator, List
# from huggingface_hub import InferenceClient
# from transformers import AutoTokenizer
from jinja2 import Environment, FileSystemLoader
from settings import *
from gradio_app.backend.ChatGptInteractor import *
# tokenizer = AutoTokenizer.from_pretrained(LLM_NAME)
# HF_TOKEN = None
# hf_client = InferenceClient(LLM_NAME, token=HF_TOKEN)
def format_prompt(message: str, api_kind: str):
    """
    Wrap a single user message in the request format expected by a backend.

    Args:
        message (str): The user message to be formatted.
        api_kind (str): Which backend the prompt is for: "openai" or "hf".

    Returns:
        For "openai": a new one-element chat message list,
        ``[{'role': 'user', 'content': message}]``.
        For "hf": the string produced by ``tokenizer.apply_chat_template`` on
        that same list (untokenized).

    Raises:
        ValueError: If ``api_kind`` is not one of the supported values.
    """
    # A fresh list per call, so callers may mutate the result safely.
    messages: List[Dict[str, Any]] = [{'role': 'user', 'content': message}]
    if api_kind == "openai":
        return messages
    if api_kind == "hf":
        # NOTE(review): `tokenizer` is not defined in the visible part of this
        # file (its definition near the imports is commented out). Unless one of
        # the star imports provides it, this branch raises NameError.
        return tokenizer.apply_chat_template(messages, tokenize=False)
    raise ValueError(f"API is not supported: {api_kind!r}")
def generate_hf(prompt: str, history: str, temperature: float = 0.9, max_new_tokens: int = 512,
                top_p: float = 0.6, repetition_penalty: float = 1.2) -> Generator[str, None, str]:
    """
    Stream a completion for ``prompt`` from the Hugging Face inference client.

    Args:
        prompt (str): The user message. It is wrapped with the chat template via
            ``format_prompt(prompt, "hf")``.
        history (str): Currently unused. It is kept so the signature mirrors
            ``_generate_openai``.
        temperature (float, optional): Sampling temperature, clamped to at least
            0.01. Defaults to 0.9.
        max_new_tokens (int, optional): Maximum number of tokens to generate.
            Defaults to 512.
        top_p (float, optional): Nucleus sampling probability. Defaults to 0.6.
        repetition_penalty (float, optional): Penalty for repeated tokens.
            Defaults to 1.2.

    Yields:
        str: The text generated so far. Each value extends the previous one.

    Returns:
        str: Only on a client error. A Gradio warning is shown and a short
        fallback message becomes the generator's return value
        (``StopIteration.value``). A plain for-loop over the generator
        discards that value.
    """
    # Keep temperature strictly positive (presumably required when sampling).
    temperature = max(float(temperature), 1e-2)
    top_p = float(top_p)
    generate_kwargs = {
        'temperature': temperature,
        # Normalize numeric inputs; UI values may arrive as floats/strings.
        'max_new_tokens': int(max_new_tokens),
        'top_p': top_p,
        'repetition_penalty': float(repetition_penalty),
        'do_sample': True,
        'seed': 42,  # fixed seed, presumably for reproducible sampling
    }
    formatted_prompt = format_prompt(prompt, "hf")
    try:
        # NOTE(review): `hf_client` is not defined in the visible part of this
        # file (its definition near the imports is commented out). Unless a star
        # import provides it, this raises NameError, which the generic handler
        # below reports.
        stream = hf_client.text_generation(formatted_prompt, **generate_kwargs,
                                           stream=True, details=True, return_full_text=False)
        output = ""
        for response in stream:
            output += response.token.text
            yield output
    except Exception as e:
        error_text = str(e)
        if "Too Many Requests" in error_text:
            print("ERROR: Too many requests on Mistral client")
            gr.Warning("Unfortunately Mistral is unable to process")
            return "Unfortunately, I am not able to process your request now."
        if "Authorization header is invalid" in error_text:
            print("Authetification error:", error_text)
            gr.Warning("Authentication error: HF token was either not provided or incorrect")
            return "Authentication error"
        print("Unhandled Exception:", error_text)
        gr.Warning("Unfortunately Mistral is unable to process")
        return "I do not know what happened, but I couldn't understand you."
# Jinja2 environment for the prompt templates. The loader path is relative, so
# it is resolved against the process's current working directory.
env = Environment(loader=FileSystemLoader(searchpath='gradio_app/templates'))
context_template = env.get_template(name='context_template.j2')
# Opening system prompt: the context template rendered with no documents.
start_system_message = context_template.render({'documents': []})
def construct_openai_messages(context, history, system_message=None):
    """
    Build the OpenAI chat ``messages`` list for the current turn.

    The list opens with a system message. Then, for each
    ``(question, answer)`` pair in ``history``:
    - the question becomes a ``user`` message;
    - a non-empty answer becomes an ``assistant`` message.
    A pair whose answer is empty or ``None`` is the pending turn. ``context``
    is inserted as a ``system`` message immediately before its question.

    Args:
        context (str): Context text for the pending question (presumably
            rendered retrieved documents; confirm with the caller).
        history: Iterable of ``(question, answer)`` pairs, oldest first. The
            pending turn is expected last, with an empty or ``None`` answer.
        system_message (str, optional): Override for the opening system prompt.
            Defaults to the module-level ``start_system_message``.

    Returns:
        list[dict]: Messages with ``role`` and ``content`` keys.
    """
    if system_message is None:
        system_message = start_system_message
    messages = [{"role": "system", "content": system_message}]
    for question, answer in history:
        # `not answer` covers both "" and None (a chat UI may use None for the
        # pending bot reply), whereas len(None) would raise TypeError.
        if not answer:
            messages.append({"role": "system", "content": context})
        messages.append({"role": "user", "content": question})
        if answer:
            messages.append({"role": "assistant", "content": answer})
    return messages
def generate_openai(messages):
    """
    Stream a reply for ``messages`` from the configured chat model.

    A new ``ChatGptInteractor`` for ``LLM_NAME`` is created per call. The
    completion is requested with ``max_tokens=512``, ``temperature=0`` and
    streaming on. One value is yielded per streamed part, namely whatever
    ``get_stream_text`` returns for it. Whether that is a delta or cumulative
    text depends on ``ChatGptInteractor`` (not visible here).
    """
    interactor = ChatGptInteractor(model_name=LLM_NAME)
    completion_stream = interactor.chat_completion(
        messages,
        max_tokens=512,
        temperature=0,
        stream=True,
    )
    for streamed_part in completion_stream:
        yield interactor.get_stream_text(streamed_part)
def _generate_openai(prompt: str, history: str, temperature: float = 0.9, max_new_tokens: int = 512,
                     top_p: float = 0.6, repetition_penalty: float = 1.2) -> Generator[str, None, str]:
    """
    Stream a completion for ``prompt`` via ``openai.ChatCompletion`` (pre-1.0 SDK).

    Args:
        prompt (str): The user message, sent as a single ``user`` chat message.
        history (str): Currently unused. It is kept so the signature mirrors
            ``generate_hf``.
        temperature (float, optional): Sampling temperature, clamped to at least
            0.01. Defaults to 0.9.
        max_new_tokens (int, optional): Sent as ``max_tokens``. Defaults to 512.
        top_p (float, optional): Nucleus sampling probability. Defaults to 0.6.
        repetition_penalty (float, optional): Sent as ``frequency_penalty``,
            clamped to [-2, 2]. Defaults to 1.2.

    Yields:
        str: The text generated so far. Each value extends the previous one.

    Returns:
        str: Only on an API error. A Gradio warning is shown and a short
        fallback message becomes the generator's return value
        (``StopIteration.value``). A plain for-loop over the generator
        discards that value.
    """
    temperature = max(float(temperature), 1e-2)
    top_p = float(top_p)
    generate_kwargs = {
        'temperature': temperature,
        # Normalize numeric inputs; UI values may arrive as floats/strings.
        'max_tokens': int(max_new_tokens),
        'top_p': top_p,
        # NOTE(review): repetition_penalty (HF-style, where 1.0 means "no
        # penalty") is passed straight through as OpenAI's additive
        # frequency_penalty (0 means "no penalty"). The default 1.2 is
        # therefore a fairly strong penalty; confirm this mapping is intended.
        'frequency_penalty': max(-2., min(float(repetition_penalty), 2.)),
    }
    formatted_prompt = format_prompt(prompt, "openai")
    try:
        # NOTE(review): `openai` is not imported in the visible part of this
        # file; it presumably comes from a star import.
        stream = openai.ChatCompletion.create(
            model=LLM_NAME,
            messages=formatted_prompt,
            **generate_kwargs,
            stream=True
        )
        output = ""
        for chunk in stream:
            if not chunk.choices:
                # Some streamed chunks can carry no choices; nothing to append.
                continue
            # A delta may omit "content" (role-only or final chunks) or set it
            # to null, so fall back to "" in both cases.
            output += chunk.choices[0].delta.get("content") or ""
            yield output
    except Exception as e:
        error_text = str(e)
        if "Too Many Requests" in error_text:
            print("ERROR: Too many requests on OpenAI client")
            gr.Warning("Unfortunately OpenAI is unable to process")
            return "Unfortunately, I am not able to process your request now."
        if "You didn't provide an API key" in error_text:
            print("Authetification error:", error_text)
            gr.Warning("Authentication error: OpenAI key was either not provided or incorrect")
            return "Authentication error"
        print("Unhandled Exception:", error_text)
        gr.Warning("Unfortunately OpenAI is unable to process")
        return "I do not know what happened, but I couldn't understand you."