from flask import Flask, request, jsonify
from llama_cpp import Llama

from model import model_download

# Download the GGUF weights into the working directory before loading the model
model_download()

# Initialize the Llama model with the chat format set to "llama-2".
# NOTE: llama-cpp-python defaults to a 512-token context (n_ctx=512);
# pass n_ctx=4096 to Llama() to use Llama 2's full context window.
llm = Llama(model_path="./llama-2-7b-chat.Q2_K.gguf", chat_format="llama-2")
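
# For reference, model_download() in model.py presumably wraps
# huggingface_hub.hf_hub_download along these lines. This is a sketch, not
# the Space's actual code; the repo_id is an assumption inferred from the
# GGUF filename:
#
#     from huggingface_hub import hf_hub_download
#
#     def model_download():
#         hf_hub_download(
#             repo_id="TheBloke/Llama-2-7B-Chat-GGUF",  # assumed repo
#             filename="llama-2-7b-chat.Q2_K.gguf",
#             local_dir=".",
#         )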

# Define the system prompt
system_prompt = (
    "I am an Indian law chatbot designed to provide legal support to marginalized communities. "
    "This model was fine-tuned by Sathish and his team members at the University College of Engineering Dindigul. "
    "The model has been trained on various legal topics. "
    "Feel free to ask questions."
)

# Seed the conversation history with the system prompt
conversation_history = [{"role": "system", "content": system_prompt}]

# Create the Flask application
app = Flask(__name__)

# Generate a reply for `query`, maintaining the running conversation history
def model(query):
    global conversation_history  # the history is updated in place

    # Add the user's query to the conversation history
    conversation_history.append({"role": "user", "content": query})

    # Approximate the token count by whitespace-splitting; see the
    # tokenizer-based alternative sketched below for an exact count
    total_tokens = sum(len(message["content"].split()) for message in conversation_history)
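
    # Exact alternative (sketch): count tokens with the model's own tokenizer
    # via llama-cpp-python's Llama.tokenize, which takes UTF-8 bytes:
    #
    #     total_tokens = sum(
    #         len(llm.tokenize(m["content"].encode("utf-8")))
    #         for m in conversation_history
    #     )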

    # If the history exceeds the model's context window, trim it.
    # 512 matches llama-cpp-python's default n_ctx; adjust this if you pass
    # a larger n_ctx to Llama() above.
    context_window_size = 512
    while total_tokens > context_window_size and len(conversation_history) > 1:
        # Drop the oldest user/assistant message, keeping the system prompt
        # at index 0
        conversation_history.pop(1)
        # Recalculate the total number of tokens
        total_tokens = sum(len(message["content"].split()) for message in conversation_history)

    # Generate a chat completion from the conversation history
    response = llm.create_chat_completion(messages=conversation_history, max_tokens=75)
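
    # create_chat_completion returns an OpenAI-style completion dict, roughly:
    # {"choices": [{"message": {"role": "assistant", "content": "..."}}], ...}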

    # Extract the assistant's reply from the completion dictionary
    if response and "choices" in response and response["choices"]:
        assistant_response = response["choices"][0]["message"]["content"].strip()

        # Add the assistant's reply to the conversation history
        conversation_history.append({"role": "assistant", "content": assistant_response})

        # Print and return the assistant's reply
        print("Assistant response:", assistant_response)
        return assistant_response
    else:
        print("Error: Invalid response structure.")
        return None

# Define the API endpoint. The route path is not shown in the original file;
# "/chat" is assumed here, since without an @app.route decorator the function
# would never be reachable.
@app.route("/chat", methods=["GET"])
def chat_endpoint():
    global conversation_history

    # Get the query parameter from the request
    query = request.args.get("query")

    # If the "refresh" parameter is set to "true", reset the conversation history
    refresh = request.args.get("refresh")
    if refresh and refresh.lower() == "true":
        conversation_history = [{"role": "system", "content": system_prompt}]
        return jsonify({"response": "Conversation history cleared."})

    # Without a query there is nothing to answer
    if not query:
        return jsonify({"error": "Query parameter is required."}), 400

    # Run the model and return the assistant's reply as JSON
    response = model(query)
    return jsonify({"response": response})

# Run the Flask app
if __name__ == "__main__":
    app.run(host="0.0.0.0", port=5000)
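
# Example usage once the server is running (the "/chat" route is the
# assumption made above; note that conversation_history is a module-level
# global, so all clients share a single conversation):
#
#     curl "http://localhost:5000/chat?query=What+does+Article+21+guarantee%3F"
#     curl "http://localhost:5000/chat?refresh=true"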