File size: 3,674 Bytes
c2cecc5
 
3280898
8c5f27f
 
c2cecc5
 
415b325
c2cecc5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import threading

from flask import Flask, request, jsonify
from llama_cpp import Llama
from huggingface_hub import hf_hub_download
from model import model_download
# NOTE(review): model_download comes from the local `model` module (not shown);
# presumably it fetches ./llama-2-7b-chat.Q2_K.gguf when missing — confirm.
# It runs at import time and must stay ahead of the Llama(...) construction
# below, which loads that file.
model_download()

# Initialize the Llama model with chat format set to "llama-2".
# The model path is relative, so it is resolved against the process's current
# working directory, not this file's directory. No n_ctx is passed, so the
# library's default context size applies; any history trimming must budget
# against that size.
llm = Llama(model_path="./llama-2-7b-chat.Q2_K.gguf", chat_format="llama-2")

# Define the system prompt, used as the first message of the conversation
# history.
# NOTE(review): it is phrased in the assistant's voice ("I am ...") rather than
# as instructions to the model — consider rewording. The loaded file name
# suggests base llama-2-7b-chat weights; confirm the fine-tuning claim holds.
system_prompt = (
    "I am an Indian law chatbot designed to provide legal support to marginalized communities. "
    "This model was fine-tuned by Sathish and his team members at the University College of Engineering Dindigul. "
    "The model has been trained on various legal topics. "
    "Feel free to ask questions."
)

# Initialize the conversation history list with the system prompt.
# This is a single module-level list shared by every /chat request; there is
# no per-client separation.
conversation_history = [{"role": "system", "content": system_prompt}]

# Create a Flask application
app = Flask(__name__)

# Rough allowance for the chat-template markup (e.g. [INST] / <<SYS>> tags)
# that the "llama-2" chat format wraps around each message. This is an
# estimate, not a measured value.
_TEMPLATE_TOKENS_PER_MESSAGE = 10


def _message_tokens(message):
    """Return the approximate prompt cost, in tokens, of one chat message."""
    content = message["content"].encode("utf-8")
    return len(llm.tokenize(content, add_bos=False)) + _TEMPLATE_TOKENS_PER_MESSAGE


def _trim_history(history, budget):
    """Drop the oldest turns from *history* in place until it fits *budget*.

    A leading system message and the newest message (the pending user query)
    are never removed. An assistant reply left at the front after its user
    question is dropped is removed too, so turns stay paired.

    Returns:
        True if the remaining messages fit within *budget* tokens.
    """
    first = 1 if history and history[0]["role"] == "system" else 0
    costs = [_message_tokens(message) for message in history]
    total = sum(costs)
    # Stop before the newest message; keep removing while over budget or while
    # the oldest remaining turn is an orphaned assistant reply.
    while len(history) - first > 1 and (
        total > budget or history[first]["role"] == "assistant"
    ):
        history.pop(first)
        total -= costs.pop(first)
    return total <= budget


# Define the model function
def model(query, *, max_tokens=75, context_window_size=None):
    """Send *query* to the model as the next user turn and return its reply.

    The query and the reply are appended to the module-level
    ``conversation_history``. Before generation, the oldest turns are dropped
    so that the prompt plus *max_tokens* of output fits the context window;
    the system prompt is always kept.

    Args:
        query: The user's message.
        max_tokens: Maximum number of tokens to generate for the reply.
        context_window_size: Context size in tokens to budget against;
            defaults to the loaded model's ``n_ctx()``.

    Returns:
        The stripped assistant reply, or None if the query cannot fit in the
        context window or the completion had no choices. In both of those
        cases the query is removed from the history again.
    """
    # Bind the current list once; the endpoint may rebind the global on refresh.
    history = conversation_history
    history.append({"role": "user", "content": query})

    if context_window_size is None:
        context_window_size = llm.n_ctx()
    # Reserve room for the generated reply so prompt + output stay within
    # the context window.
    budget = context_window_size - max_tokens
    if not _trim_history(history, budget):
        history.pop()
        print("Error: Query is too long for the model's context window.")
        return None

    try:
        response = llm.create_chat_completion(messages=history, max_tokens=max_tokens)
    except Exception:
        # Do not leave a user turn without a reply in the history.
        history.pop()
        raise

    if response and response.get("choices"):
        assistant_response = response["choices"][0]["message"]["content"].strip()
        history.append({"role": "assistant", "content": assistant_response})
        print("Assistant response:", assistant_response)
        return assistant_response

    history.pop()
    print("Error: Invalid response structure.")
    return None


# Serializes access to the shared conversation history and the single Llama
# instance, since Flask's development server handles requests on multiple
# threads by default.
_conversation_lock = threading.Lock()


# Define the endpoint for the API
@app.route("/chat", methods=["GET"])
def chat_endpoint():
    """Handle ``GET /chat``: answer ``query`` or reset the conversation.

    Query parameters:
        query: The user's message; required unless ``refresh`` is "true".
        refresh: If "true" (case-insensitive), reset the conversation history
            to just the system prompt; ``query`` is then ignored.

    Returns:
        JSON ``{"response": ...}`` on success, ``{"error": ...}`` with HTTP 400
        when ``query`` is missing, or HTTP 500 when the model gave no reply.
    """
    global conversation_history

    query = request.args.get("query")
    refresh = request.args.get("refresh")

    if refresh and refresh.lower() == "true":
        with _conversation_lock:
            conversation_history = [{"role": "system", "content": system_prompt}]
        return jsonify({"response": "Conversation history cleared."})

    if not query:
        return jsonify({"error": "Query parameter is required."}), 400

    # NOTE(review): one history is shared by every client, so concurrent users
    # see each other's turns; per-user history would need a session key.
    with _conversation_lock:
        response = model(query)

    if response is None:
        return jsonify({"response": None, "error": "Model returned no response."}), 500
    return jsonify({"response": response})

# Start Flask's built-in server only when this file is executed directly;
# importing the module does not launch it. Binding to 0.0.0.0 listens on all
# network interfaces, so the server is reachable from other hosts.
if __name__ == "__main__":
    server_options = {"host": "0.0.0.0", "port": 5000}
    app.run(**server_options)