File size: 2,407 Bytes
7416d8a
5c0e14a
7416d8a
b35606a
7416d8a
 
5c0e14a
b35606a
5c0e14a
 
7416d8a
08eb742
b35606a
 
08eb742
5c0e14a
7416d8a
b35606a
5c0e14a
 
7416d8a
 
 
 
 
 
 
 
 
5c0e14a
 
7416d8a
 
5c0e14a
7416d8a
 
 
17fba42
7416d8a
 
 
b5fc4fe
 
 
 
 
 
b35606a
7416d8a
b5fc4fe
7416d8a
 
 
 
 
 
 
 
 
 
 
b35606a
7416d8a
 
 
 
 
 
 
 
b35606a
 
5c0e14a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import json
import gradio as gr
import random
from huggingface_hub import InferenceClient

# Base URL of the Hugging Face Inference API. Not referenced by the code in this file.
API_URL = "https://api-inference.huggingface.co/models/"

# Module-level client for the hosted Mistral-7B-Instruct model; every call to generate() streams through it.
client = InferenceClient("mistralai/Mistral-7B-Instruct-v0.1")

def format_prompt(message, history):
    """Build a Mistral-instruct style prompt from the conversation so far.

    Args:
        message: The new user message to be answered.
        history: Iterable of ``(user_prompt, bot_response)`` pairs for earlier turns.

    Returns:
        A single prompt string: a fixed preamble, each past turn wrapped as
        ``[INST] user [/INST] bot</s>``, then the new message in ``[INST]`` tags.
    """
    past_turns = [
        f" [INST] {user_turn} [/INST] {bot_turn}</s> "
        for user_turn, bot_turn in history
    ]
    return "You're a helpful assistant." + "".join(past_turns) + f" [INST] {message} [/INST]"

def generate(prompt, history, temperature=0.9, max_new_tokens=2048, top_p=0.95, repetition_penalty=1.0):
    """Stream a sampled completion for *prompt* from the module-level ``client``.

    Args:
        prompt: The new user message.
        history: Earlier ``(user, bot)`` turns, formatted by ``format_prompt``.
        temperature: Sampling temperature. Non-positive values are clamped to
            0.01 because sampling is always enabled (``do_sample=True``).
        max_new_tokens: Upper bound on the number of generated tokens.
        top_p: Nucleus-sampling probability mass.
        repetition_penalty: Penalty applied to repeated tokens.

    Yields:
        The accumulated response text after each streamed token. The last
        value yielded is the complete response.
    """
    # Convert before comparing, so numeric strings (e.g. from UI widgets) work too.
    temperature = float(temperature)
    if not temperature > 0:
        # Presumably the backend rejects a non-positive temperature when sampling.
        temperature = 0.01
    top_p = float(top_p)
    repetition_penalty = float(repetition_penalty)
    max_new_tokens = int(max_new_tokens)

    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        # Fresh seed per call so repeated prompts can produce different samples.
        seed=random.randint(0, 10**7),
    )

    formatted_prompt = format_prompt(prompt, history)

    stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
    output = ""

    for response in stream:
        # details=True exposes token metadata. Skip special tokens (e.g. the
        # end-of-sequence marker) so they don't leak into the visible text.
        if response.token.special:
            continue
        output += response.token.text
        yield output

def load_database():
    """Load the cached conversation list from ``database.json``.

    Returns:
        The decoded top-level JSON array. If the file is missing, is not
        valid JSON, or does not hold a JSON array, a message is printed and
        an empty list is returned instead.
    """
    try:
        with open("database.json", "r", encoding="utf-8") as handle:
            records = json.load(handle)
        # Only a top-level JSON array counts as a usable database.
        if not isinstance(records, list):
            raise ValueError("Invalid data format")
    except (FileNotFoundError, json.JSONDecodeError, ValueError):
        print("Error loading database: File not found, invalid format, or empty. Creating an empty database.")
        return []
    return records


def save_database(data):
    """Persist *data* to ``database.json`` on a best-effort basis.

    The data is serialized to a temporary sibling file first and then moved
    over ``database.json`` with ``os.replace`` (atomic on POSIX). A failure
    part-way through serialization therefore cannot leave a truncated
    database behind; the previous file, if any, stays intact.

    Failures are reported with a printed message and otherwise swallowed,
    so callers never see an exception from saving.

    Args:
        data: A JSON-serializable object, normally the list from load_database().
    """
    import os

    tmp_path = "database.json.tmp"
    try:
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(data, f, indent=4)
        os.replace(tmp_path, "database.json")
    # The json module has no JSONEncodeError. Encoding failures surface as
    # TypeError (unserializable object) or ValueError (e.g. circular reference).
    except (OSError, TypeError, ValueError):
        print("Error saving database: Encountered an issue while saving.")
        try:
            os.remove(tmp_path)
        except OSError:
            # Cleanup is best effort; the temp file may never have been created.
            pass

def chat_interface(message):
    """Answer *message*, reusing a cached reply from ``database.json`` when available.

    On a cache miss, the full model response is generated. A non-empty
    response is appended to the database and saved.

    Args:
        message: The user's text from the Gradio textbox.

    Returns:
        The cached or newly generated response text.
    """
    database = load_database()

    # json.dump writes (message, response) tuples as arrays, and json.load
    # returns them as lists. Match on the message itself, not on a tuple.
    for entry in database:
        if isinstance(entry, (list, tuple)) and len(entry) == 2 and entry[0] == message:
            return entry[1]

    # generate() yields the accumulated text after every token. Drain it and
    # keep the last value, so the whole reply is returned rather than only
    # the first token.
    response = ""
    for partial in generate(message, history=[]):
        response = partial

    # Don't cache an empty reply; otherwise a failed generation would be
    # served from the cache forever.
    if response:
        database.append([message, response])
        save_database(database)

    return response

# Build the Gradio UI around chat_interface. The interface is a plain object
# (no `with` re-entry), and the server starts only when this file is run as a
# script, so importing the module has no side effects.
iface = gr.Interface(fn=chat_interface, inputs="textbox", outputs="textbox", title="Chat Interface")

if __name__ == "__main__":
    iface.launch()