Vitrous committed on
Commit eb1b64a · verified · 1 parent: 48e03de

Update app.py

Files changed (1): app.py (+91 -0)
app.py CHANGED
@@ -5,6 +5,11 @@ import torch
 import optimum
 from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
 from fastapi.responses import HTMLResponse
+from datetime import datetime
+import random
+import string
+from datasets import Dataset
+import json
 
 # Set environment variables for GPU usage and memory allocation
 os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
 
@@ -14,6 +19,7 @@ torch.cuda.set_per_process_memory_fraction(0.8) # Adjust the fraction as needed
 # Initialize FastAPI application
 app = FastAPI(root_path="/api/v1")
 
+
 # Load the model and tokenizer
 model_name_or_path = "TheBloke/Wizard-Vicuna-7B-Uncensored-GPTQ"
 
 
@@ -74,6 +80,51 @@ def hermes_model():
 
 model, tokenizer = hermes_model()
 
+def generate_id(length=5):
+    """
+    Generates a random alphanumeric ID.
+
+    Args:
+        length (int): The length of the ID.
+
+    Returns:
+        str: A random alphanumeric ID.
+    """
+    return ''.join(random.choices(string.ascii_letters + string.digits, k=length))
+
+def generate_thread_id():
+    """
+    Generates a random thread ID for each conversation.
+
+    Returns:
+        str: A random thread ID.
+    """
+    return generate_id()
+
+def generate_message_id():
+    """
+    Generates a random alphanumeric message ID.
+
+    Returns:
+        str: A random alphanumeric message ID.
+    """
+    return generate_id()
+
+
+def save_conversation(user_id, conversation):
+    """
+    Save conversation history to disk.
+
+    Args:
+        user_id (str): The unique identifier for the user.
+        conversation (dict): The conversation data.
+    """
+    # Directory for the JSONL log (a local constant here, not a parameter)
+    hf_space_path = "articko/conversations"
+    with open(f'{hf_space_path}/conversations.jsonl', 'a') as file:
+        json.dump({user_id: conversation}, file)
+        file.write('\n')
+
 def chat_response(msg_prompt: str) -> dict:
     """
     Generates a response from the model given a prompt.
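
Note: save_conversation appends one {user_id: conversation} JSON object per line, so the log is plain JSON Lines. A minimal read-back sketch (the relative path mirrors the constant above; the `from datasets import Dataset` import added in this commit is not yet used, though the same file could later be loaded that way):

import json

def load_conversations(path='articko/conversations/conversations.jsonl'):
    # Each line holds a single {user_id: conversation} object.
    entries = []
    with open(path) as file:
        for line in file:
            record = json.loads(line)
            for user_id, conversation in record.items():
                entries.append((user_id, conversation))
    return entries

Since generate_id draws 5 characters from 62 alphanumerics, there are about 916 million possible IDs; thread and message IDs are random, not guaranteed unique.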
 
@@ -202,6 +253,46 @@ async def hermes_chat(request: Request):
         raise
     except Exception as e:
         raise HTTPException(status_code=500, detail=str(e))
+
+@app.post('/chat_thread/{user_id}')
+async def chat_thread(request: Request, user_id: str):
+    """
+    Starts a new conversation thread with a provided prompt for a specific user.
+
+    Args:
+        request (Request): The HTTP request object containing the user prompt.
+        user_id (str): The unique identifier for the user.
+
+    Returns:
+        dict: The response generated by the model along with the user's conversation history.
+    """
+    try:
+        thread_id = generate_thread_id()
+
+        data = await request.json()
+        msg_prompt = data.get('msg_prompt')
+
+        if not msg_prompt:
+            raise HTTPException(status_code=400, detail="Prompt not provided")
+
+        # Generate response
+        response = chat_response(msg_prompt)
+
+        # Generate message ID
+        message_id = generate_message_id()
+
+        # Construct conversation entry
+        conversation_entry = {'thread_id': thread_id, 'message_id': message_id, 'user': msg_prompt, 'assistant': response}
+
+        # Save conversation history to disk
+        save_conversation(user_id, conversation_entry)
+
+        # Return response and thread ID
+        return {'response': conversation_entry}
+    except HTTPException as e:
+        raise e
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=str(e))
 
 @app.post('/prompted_chat')
 async def prompted_chat(request: Request):
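
The new route accepts POST /chat_thread/{user_id} with a JSON body carrying msg_prompt and returns {'response': conversation_entry}. A hedged usage sketch (host and port are assumptions, not part of the commit; root_path="/api/v1" only adds a public prefix when the app runs behind a proxy that strips it):

import requests

base = 'http://localhost:8000'  # assumed host/port; adjust for the deployment
resp = requests.post(f'{base}/chat_thread/demo-user',
                     json={'msg_prompt': 'Hello, who are you?'})
resp.raise_for_status()
entry = resp.json()['response']  # thread_id, message_id, user, assistant
print(entry['thread_id'], entry['assistant'])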