#
# SPDX-FileCopyrightText: Hadad <[email protected]>
# SPDX-License-Identifier: Apache-2.0
#
import json  # JSON encoding and decoding for tool payloads
import uuid  # UUID generation for unique session identifiers
from typing import Any, List  # Typing annotations for type hints
from config import model  # Model configuration dictionary
from src.core.server import jarvis  # Async generator that talks to the AI backend
from src.core.parameter import parameters  # Parameters (not used directly here, imported for completeness)
from src.core.session import session  # Session dictionary storing conversation histories
from src.tools.audio import AudioGeneration  # AudioGeneration class handling audio creation
from src.tools.image import ImageGeneration  # ImageGeneration class handling image creation
from src.tools.deep_search import SearchTools  # SearchTools class for deep search functionality
import gradio as gr  # Gradio library for UI and request handling
# Define an asynchronous generator 'respond' that processes user messages and streams AI responses.
# The incoming history arrives in the legacy "tuples" style (a list of [user_msg, assistant_msg]
# pairs) and is converted below into the "messages" style: a flat list of dicts with "role" and
# "content" keys, where content may be a string, a dict with a "path" key, or a Gradio component.
async def respond(
    message,  # Incoming user message: a plain string, or a dict containing text and files
    history: List[Any],  # Conversation history as [user_msg, assistant_msg] pairs (tuples style)
    model_label,  # Label/key selecting the AI model from the available models
    temperature,  # Sampling temperature controlling the randomness of response generation
    top_k,  # Number of highest-probability tokens to keep for sampling
    min_p,  # Minimum probability threshold for token sampling
    top_p,  # Cumulative probability threshold for nucleus sampling
    repetition_penalty,  # Penalty factor that discourages repetitive tokens in generated text
    thinking,  # Boolean flag indicating whether the AI should operate in "thinking" mode
    image_gen,  # Boolean flag enabling the '/image' generation command
    audio_gen,  # Boolean flag enabling the '/audio' generation command
    search_gen,  # Boolean flag enabling the '/dp' deep search command
    request: gr.Request  # Gradio request object exposing session information such as the session hash
):
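    """
    Stream chat responses for the Gradio UI.

    Converts the tuples-style history into messages-style entries, dispatches
    the special /audio, /image, and /dp commands to their respective tools,
    and otherwise streams the model's reply from jarvis.
    """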
    # Select the AI model for the given label; fall back to the first configured model if the label is not found
    selected_model = model.get(model_label, list(model.values())[0])
    # Instantiate SearchTools to enable deep search capabilities when requested
    search_tools = SearchTools()
    # Use the Gradio session hash as the session ID, generating a new UUID if none exists
    session_id = request.session_hash or str(uuid.uuid4())
    # Initialize an empty conversation history for this session if it does not already exist
    if session_id not in session:
        session[session_id] = []
    # Determine the mode string from the 'thinking' flag, which affects response generation behavior
    mode = "/think" if thinking else "/no_think"
    # Initialize the user input text and any attached file
    # (named 'user_input' rather than 'input' to avoid shadowing the builtin)
    user_input = ""
    files = None
    # The incoming message may be a dictionary containing both text and files
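    # A multimodal payload from Gradio typically looks like this (illustrative values):
    #   {"text": "describe this image", "files": ["/tmp/gradio/cat.png"]}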
    if isinstance(message, dict):
        # Extract the text content, defaulting to an empty string if missing
        user_input = message.get("text", "")
        # Take the first attached file if any are present; otherwise leave files as None
        files = message.get("files")[0] if message.get("files") else None
    else:
        # A plain string message is used directly as the input text
        user_input = message
    # Strip leading and trailing whitespace for clean processing
    stripped_input = user_input.strip()
    # Lowercase the stripped input for case-insensitive command detection
    lowered_input = stripped_input.lower()
    # If the input is empty after stripping, yield an empty update and exit early
    if not stripped_input:
        yield []
        return
    # If the input is exactly a command keyword without parameters, yield an empty update and exit early
    if lowered_input in ["/audio", "/image", "/dp"]:
        yield []
        return
    # Rebuild the conversation history in the format the AI model consumes:
    # convert the "tuples" style history (a list of [user_msg, assistant_msg] pairs)
    # into the "messages" style, a flat list of dicts with "role" and "content" keys.
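    # For example (illustrative values), the entry ["Hi", "Hello!"] becomes two entries:
    #   {"role": "user", "content": "Hi"} and {"role": "assistant", "content": "Hello!"}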
    new_history = []
    for entry in history:
        # Only process entries that are [user_msg, assistant_msg] pairs
        if isinstance(entry, list) and len(entry) == 2:
            user_msg, assistant_msg = entry
            # Append the user message with role 'user' if it is not None
            if user_msg is not None:
                new_history.append({"role": "user", "content": user_msg})
            # Append the assistant message with role 'assistant' if it is not None
            if assistant_msg is not None:
                new_history.append({"role": "assistant", "content": assistant_msg})
    # Store the freshly formatted conversation history for this session
    session[session_id] = new_history
    # Handle the audio generation command if enabled and the input starts with '/audio'
    if audio_gen and lowered_input.startswith("/audio"):
        # Extract the instruction text after the '/audio' prefix, slicing the stripped
        # input so that leading whitespace cannot shift the offset
        audio_instruction = stripped_input[6:].strip()
        # If no instruction text is provided, yield an empty update and exit early
        if not audio_instruction:
            yield []
            return
        try:
            # Asynchronously create audio content from the instruction
            audio = await AudioGeneration.create_audio(audio_instruction)
            # Serialize the audio data and instruction into a JSON string
            audio_generation_content = json.dumps({
                "audio": audio,
                "audio_instruction": audio_instruction
            })
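            # The serialized content looks like this (illustrative values; the HTML audio
            # tag instructions below assume 'audio' is a URL):
            #   {"audio": "https://example.com/out.mp3", "audio_instruction": "calm piano"}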
            # Build a history that appends the audio result plus display instructions
            audio_generation_result = (
                new_history
                + [
                    {
                        "role": "system",
                        "content": (
                            f"Audio generation result:\n\n{audio_generation_content}\n\n\n"
                            "Show the audio using the following HTML audio tag format, where '{audio_link}' is the URL of the generated audio:\n\n"
                            "<audio controls src='{audio_link}' style='width:100%; max-width:100%;'></audio>\n\n"
                            "Please replace '{audio_link}' with the actual audio URL provided in the context.\n\n"
                            "Then, describe the generated audio based on the above information.\n\n\n"
                            "Use the same language as the previous user input or user request.\n"
                            "For example, if the previous user input or user request is in Indonesian, explain in Indonesian.\n"
                            "If it is in English, explain in English. This also applies to other languages.\n\n\n"
                        )
                    }
                ]
            )
            # Stream descriptive text about the generated audio
            async for audio_description in jarvis(
                session_id=session_id,
                model=selected_model,
                history=audio_generation_result,
                user_message=user_input,
                mode="/no_think",  # Skip "thinking" mode for this fixed description step
                temperature=0.7,  # Fixed temperature for audio description generation
                top_k=20,  # Limit token sampling to the top 20 tokens
                min_p=0,  # Minimum probability threshold
                top_p=0.8,  # Nucleus sampling threshold
                repetition_penalty=1.0  # No repetition penalty for this step
            ):
                # Yield the audio description wrapped in a tool role for UI display
                yield [{"role": "tool", "content": audio_description}]
            return
        except Exception:
            # If audio generation fails, yield an error message and exit
            yield [{"role": "tool", "content": "Audio generation failed. Please wait 15 seconds before trying again."}]
            return
    # Handle the image generation command if enabled and the input starts with '/image'
    if image_gen and lowered_input.startswith("/image"):
        # Extract the instruction text after the '/image' prefix, slicing the stripped input
        generate_image_instruction = stripped_input[6:].strip()
        # If no instruction text is provided, yield an empty update and exit early
        if not generate_image_instruction:
            yield []
            return
        try:
            # Asynchronously create image content from the instruction
            image = await ImageGeneration.create_image(generate_image_instruction)
            # Serialize the image data and instruction into a JSON string
            image_generation_content = json.dumps({
                "image": image,
                "generate_image_instruction": generate_image_instruction
            })
            # Build a history that appends the image result plus display instructions
            image_generation_result = (
                new_history
                + [
                    {
                        "role": "system",
                        "content": (
                            f"Image generation result:\n\n{image_generation_content}\n\n\n"
                            "Show the generated image using the following markdown syntax format, where '{image_link}' is the URL of the image:\n\n"
                            "![Generated Image]({image_link})\n\n"
                            "Please replace '{image_link}' with the actual image URL provided in the context.\n\n"
                            "Then, describe the generated image based on the above information.\n\n\n"
                            "Use the same language as the previous user input or user request.\n"
                            "For example, if the previous user input or user request is in Indonesian, explain in Indonesian.\n"
                            "If it is in English, explain in English. This also applies to other languages.\n\n\n"
                        )
                    }
                ]
            )
            # Stream descriptive text about the generated image
            async for image_description in jarvis(
                session_id=session_id,
                model=selected_model,
                history=image_generation_result,
                user_message=user_input,
                mode="/no_think",  # Skip "thinking" mode for this fixed description step
                temperature=0.7,  # Fixed temperature for image description generation
                top_k=20,  # Limit token sampling to the top 20 tokens
                min_p=0,  # Minimum probability threshold
                top_p=0.8,  # Nucleus sampling threshold
                repetition_penalty=1.0  # No repetition penalty for this step
            ):
                # Yield the image description wrapped in a tool role for UI display
                yield [{"role": "tool", "content": image_description}]
            return
        except Exception:
            # If image generation fails, yield an error message and exit
            yield [{"role": "tool", "content": "Image generation failed. Please wait 15 seconds before trying again."}]
            return
    # Handle the deep search command if enabled and the input starts with '/dp'
    if search_gen and lowered_input.startswith("/dp"):
        # Extract the search query after the '/dp' prefix, slicing the stripped input
        search_query = stripped_input[3:].strip()
        # If no search query is provided, yield an empty update and exit early
        if not search_query:
            yield []
            return
        try:
            # Perform an asynchronous deep search with the given query
            search_results = await search_tools.search(search_query)
            # Serialize the query and the results (truncated to the first 5000 characters) into a JSON string
            search_content = json.dumps({
                "query": search_query,
                "search_results": search_results[:5000]
            })
            # Build a history that appends the search results plus summarization instructions
            search_instructions = (
                new_history
                + [
                    {
                        "role": "system",
                        "content": (
                            f"Deep search results for query: '{search_query}':\n\n{search_content}\n\n\n"
                            "Please analyze these search results and provide a comprehensive summary of the information.\n"
                            "Identify the most relevant information related to the query.\n"
                            "Format your response in a clear, structured way with appropriate headings and bullet points if needed.\n"
                            "If the search results don't provide sufficient information, acknowledge this limitation.\n"
                            "Please provide links or URLs from each of your search results.\n\n"
                            "Use the same language as the previous user input or user request.\n"
                            "For example, if the previous user input or user request is in Indonesian, explain in Indonesian.\n"
                            "If it is in English, explain in English. This also applies to other languages.\n\n\n"
                        )
                    }
                ]
            )
            # Stream a summary generated from the deep search results
            async for search_response in jarvis(
                session_id=session_id,
                model=selected_model,
                history=search_instructions,
                user_message=user_input,
                mode=mode,  # Use the mode determined by the thinking flag
                temperature=temperature,
                top_k=top_k,
                min_p=min_p,
                top_p=top_p,
                repetition_penalty=repetition_penalty
            ):
                # Yield the search summary wrapped in a tool role for UI display
                yield [{"role": "tool", "content": search_response}]
            return
        except Exception:
            # If the deep search fails, yield an error message and exit
            yield [{"role": "tool", "content": "Search failed. Please try again later."}]
            return
    # For all other inputs that match no special command, stream a regular response from jarvis
    async for response in jarvis(
        session_id=session_id,
        model=selected_model,
        history=new_history,  # Conversation history in "messages" style format
        user_message=user_input,
        mode=mode,  # Use the mode determined by the thinking flag
        files=files,  # Pass any attached files along with the message
        temperature=temperature,
        top_k=top_k,
        min_p=min_p,
        top_p=top_p,
        repetition_penalty=repetition_penalty
    ):
        # Yield each chunk of the response as it is generated
        yield response
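

# A minimal sketch of how this handler might be wired into a UI. This block is not part
# of the original module; the component labels, ranges, and defaults are illustrative
# assumptions rather than the app's real configuration, and the additional_inputs order
# must match the parameter order of 'respond' after (message, history).
if __name__ == "__main__":
    demo = gr.ChatInterface(
        fn=respond,  # gr.Request is injected automatically because 'respond' declares it
        multimodal=True,  # Deliver messages as {"text": ..., "files": [...]} dicts
        additional_inputs=[
            gr.Dropdown(choices=list(model.keys()), label="Model"),
            gr.Slider(0.0, 2.0, value=0.7, label="Temperature"),
            gr.Slider(1, 100, value=20, step=1, label="Top-k"),
            gr.Slider(0.0, 1.0, value=0.0, label="Min-p"),
            gr.Slider(0.0, 1.0, value=0.8, label="Top-p"),
            gr.Slider(0.5, 2.0, value=1.0, label="Repetition penalty"),
            gr.Checkbox(label="Thinking mode"),
            gr.Checkbox(label="Enable /image"),
            gr.Checkbox(label="Enable /audio"),
            gr.Checkbox(label="Enable /dp deep search"),
        ],
    )
    demo.launch()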