MatteoScript commited on
Commit
e5def25
·
verified ·
1 Parent(s): 652d01a

Delete chat_client.py

Browse files
Files changed (1) hide show
  1. chat_client.py +0 -86
chat_client.py DELETED
@@ -1,86 +0,0 @@
1
- from huggingface_hub import InferenceClient
2
- import os
3
- from dotenv import load_dotenv
4
- import random
5
- import json
6
- from openai import OpenAI
7
-
8
- load_dotenv()
9
- API_TOKEN = os.getenv('HF_TOKEN')
10
-
11
- def format_prompt(message, history):
12
- prompt = "<s>"
13
- for user_prompt, bot_response in history:
14
- prompt += f"[INST] {user_prompt} [/INST]"
15
- prompt += f" {bot_response}</s> "
16
- prompt += f"[INST] {message} [/INST]"
17
- return prompt
18
-
19
- def format_prompt_openai(system_prompt, message, history):
20
- messages = []
21
- if system_prompt != '':
22
- messages.append({"role": "system", "content": system_prompt})
23
- for user_prompt, bot_response in history:
24
- messages.append({"role": "user", "content": user_prompt})
25
- messages.append({"role": "assistant", "content": bot_response})
26
- messages.append({"role": "user", "content": message})
27
- return messages
28
-
29
- def chat_huggingface(prompt, history, chat_client, temperature, max_new_tokens, top_p, repetition_penalty):
30
- client = InferenceClient(
31
- chat_client,
32
- token=API_TOKEN
33
- )
34
- temperature = float(temperature)
35
- if temperature < 1e-2:
36
- temperature = 1e-2
37
- top_p = float(top_p)
38
-
39
- generate_kwargs = dict(
40
- temperature=temperature,
41
- max_new_tokens=max_new_tokens,
42
- top_p=top_p,
43
- repetition_penalty=repetition_penalty,
44
- do_sample=True,
45
- seed=random.randint(0, 10**7),
46
- )
47
- formatted_prompt = format_prompt(prompt, history)
48
- print('***************************************************')
49
- print(formatted_prompt)
50
- print('***************************************************')
51
- stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
52
- return stream
53
-
54
- def chat_openai(prompt, history, chat_client, temperature, max_new_tokens, top_p, repetition_penalty, client_openai):
55
- try:
56
- prompt = prompt.replace('\n', '')
57
- json_data = json.loads(prompt)
58
- user_prompt = json_data["messages"][1]["content"]
59
- system_prompt = json_data["input"]["content"]
60
- system_style = json_data["input"]["style"]
61
- instructions = json_data["messages"][0]["content"]
62
- if instructions != '':
63
- system_prompt += '\n' + instructions
64
- if system_style != '':
65
- system_prompt += '\n' + system_style
66
- except:
67
- user_prompt = prompt
68
- system_prompt = ''
69
- messages = format_prompt_openai(system_prompt, user_prompt, history)
70
- print('***************************************************')
71
- print(messages)
72
- print('***************************************************')
73
- stream = client_openai.chat.completions.create(
74
- model=chat_client,
75
- stream=True,
76
- messages=messages,
77
- temperature=temperature,
78
- max_tokens=max_new_tokens,
79
- )
80
- return stream
81
-
82
- def chat(prompt, history, chat_client,temperature=0.9, max_new_tokens=1024, top_p=0.95, repetition_penalty=1.0, client_openai = None):
83
- if chat_client[:3] == 'gpt':
84
- return chat_openai(prompt, history, chat_client, temperature, max_new_tokens, top_p, repetition_penalty, client_openai)
85
- else:
86
- return chat_huggingface(prompt, history, chat_client, temperature, max_new_tokens, top_p, repetition_penalty)