acecalisto3 committed (verified)
Commit 1d256a1 · Parent(s): e1339b9

Update app.py

Files changed (1):
  1. app.py +42 -34

app.py CHANGED
@@ -108,10 +108,10 @@ def run_gpt(
         safe_search=safe_search,
     ) + prompt_template.format(**prompt_kwargs)
     if VERBOSE:
-        logging.info(LOG_PROMPT.format(content))  # Log the prompt
+        logging.info(LOG_PROMPT.format(content=content))  # Log the prompt
     resp = client.text_generation(content, max_new_tokens=max_new_tokens, stop_sequences=stop_tokens, temperature=0.7, top_p=0.8, repetition_penalty=1.5)
     if VERBOSE:
-        logging.info(LOG_RESPONSE.format([resp]))  # Log the response
+        logging.info(LOG_RESPONSE.format(resp=resp))  # Log the response
     return resp
 
 def generate(
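
Note: the keyword-argument fix in this hunk assumes LOG_PROMPT and LOG_RESPONSE use named placeholders. A minimal sketch of templates consistent with the new calls (the real strings are defined elsewhere in app.py and may differ):

# Hypothetical templates; their actual definitions are not part of this diff.
LOG_PROMPT = "PROMPT: {content}"
LOG_RESPONSE = "RESPONSE: {resp}"

LOG_PROMPT.format(content="hi")  # -> "PROMPT: hi"
# The old positional call LOG_PROMPT.format("hi") raises KeyError: 'content'
# against a named placeholder, which is what this hunk fixes.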
@@ -121,15 +121,14 @@ def generate(
     logging.info(f"Seed: {seed}")  # Log the seed
 
     # Set the agent prompt based on agent_name
+    agent = "You are a helpful AI assistant."
     if agent_name == "WEB_DEV":
-        agent = "You are a helpful AI assistant. You are a web developer."
+        agent += " You are a web developer."
     elif agent_name == "AI_SYSTEM_PROMPT":
-        agent = "You are a helpful AI assistant. You are an AI system."
+        agent += " You are an AI system."
     elif agent_name == "PYTHON_CODE_DEV":
-        agent = "You are a helpful AI assistant. You are a Python code developer."
-    else:
-        agent = "You are a helpful AI assistant."
-
+        agent += " You are a Python code developer."
+
     system_prompt = f"{agent} {sys_prompt}".strip()
 
     temperature = max(float(temperature), 1e-2)
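
Note: hoisting the shared base sentence means each branch appends only its specialization, so the stand-alone else branch becomes unnecessary. An equivalent table-driven variant, sketched here for comparison (not what the commit uses):

# Sketch only: a dict lookup equivalent to the new if/elif chain.
AGENT_SUFFIXES = {
    "WEB_DEV": " You are a web developer.",
    "AI_SYSTEM_PROMPT": " You are an AI system.",
    "PYTHON_CODE_DEV": " You are a Python code developer.",
}

def build_agent_prompt(agent_name: str) -> str:
    # Unknown names fall back to the plain assistant persona.
    return "You are a helpful AI assistant." + AGENT_SUFFIXES.get(agent_name, "")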
@@ -142,9 +141,10 @@
     formatted_prompt = format_prompt(formatted_prompt, history, max_history_turns=5)  # Truncated history
     logging.info(f"Formatted Prompt: {formatted_prompt}")
 
-    client = InferenceClient(model) if model != "mistralai/Mixtral-8x7B-Instruct-v0.1" else InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")
+    # Conditionally create client
+    this_client = InferenceClient(model) if model != "mistralai/Mixtral-8x7B-Instruct-v0.1" else client
 
-    stream = client.text_generation(
+    stream = this_client.text_generation(
         formatted_prompt,
         temperature=temperature,
         max_new_tokens=max_new_tokens,
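
Note: this_client now reuses a shared client for the default Mixtral model instead of constructing a second identical InferenceClient. A sketch of the assumed surrounding setup (the module-level client is implied by the diff, not visible in it):

from huggingface_hub import InferenceClient

DEFAULT_MODEL = "mistralai/Mixtral-8x7B-Instruct-v0.1"
client = InferenceClient(DEFAULT_MODEL)  # assumed module-level shared client

def pick_client(model: str) -> InferenceClient:
    # Build a fresh client only for non-default models; otherwise reuse the shared one.
    return InferenceClient(model) if model != DEFAULT_MODEL else client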
@@ -329,27 +329,31 @@ def format_prompt(message, history, max_history_turns=5):
         prompt += f" {bot_response}</s> "
     prompt += f"[INST] {message} [/INST]"
     return prompt
+
 agents =[
     "WEB_DEV",
     "AI_SYSTEM_PROMPT",
     "PYTHON_CODE_DEV"
 ]
+
 def generate(
     prompt, history, agent_name=agents[0], sys_prompt="", temperature=0.9, max_new_tokens=2048, top_p=0.95, repetition_penalty=1.0, model="mistralai/Mixtral-8x7B-Instruct-v0.1"
 ):
     seed = random.randint(1,1111111111111111)
+    logging.info(f"Seed: {seed}")  # Log the seed
 
-    # Correct the line:
+    # Set the agent prompt based on agent_name
+    agent = "You are a helpful AI assistant."
     if agent_name == "WEB_DEV":
-        agent = "You are a helpful AI assistant. You are a web developer."
-    if agent_name == "AI_SYSTEM_PROMPT":
-        agent = "You are a helpful AI assistant. You are an AI system."
-    if agent_name == "PYTHON_CODE_DEV":
-        agent = "You are a helpful AI assistant. You are a Python code developer."
-    system_prompt = agent
-    temperature = float(temperature)
-    if temperature < 1e-2:
-        temperature = 1e-2
+        agent += " You are a web developer."
+    elif agent_name == "AI_SYSTEM_PROMPT":
+        agent += " You are an AI system."
+    elif agent_name == "PYTHON_CODE_DEV":
+        agent += " You are a Python code developer."
+
+    system_prompt = f"{agent} {sys_prompt}".strip()
+
+    temperature = max(float(temperature), 1e-2)
     top_p = float(top_p)
 
     # Add the system prompt to the beginning of the prompt
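
Note: this second generate definition shadows the first at import time, and the hunk brings its body in line with the earlier one. The old chain of independent if statements had no fallback, so an unrecognized agent name crashed before the prompt was built. A minimal reproduction of that failure mode (illustrative):

# Old behavior, reduced: no branch matches, so `agent` is never bound.
agent_name = "UNKNOWN"
if agent_name == "WEB_DEV":
    agent = "You are a helpful AI assistant. You are a web developer."
if agent_name == "AI_SYSTEM_PROMPT":
    agent = "You are a helpful AI assistant. You are an AI system."
system_prompt = agent  # NameError: name 'agent' is not defined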
@@ -358,14 +362,28 @@ def generate(
     # Use 'prompt' here instead of 'message'
     formatted_prompt = format_prompt(formatted_prompt, history, max_history_turns=5)  # Truncated history
     logging.info(f"Formatted Prompt: {formatted_prompt}")
-    stream = client.text_generation(formatted_prompt, temperature=temperature, max_new_tokens=max_new_tokens, top_p=top_p, repetition_penalty=repetition_penalty, stream=True, details=True, return_full_text=False)
-    resp = ""
+
+    # Conditionally create client
+    this_client = InferenceClient(model) if model != "mistralai/Mixtral-8x7B-Instruct-v0.1" else client
+
+    stream = this_client.text_generation(
+        formatted_prompt,
+        temperature=temperature,
+        max_new_tokens=max_new_tokens,
+        top_p=top_p,
+        repetition_penalty=repetition_penalty,
+        stream=True,
+        details=True,
+        return_full_text=False
+    )
+    resp = ""
     for response in stream:
         resp += response.token.text
         yield resp  # This allows for streaming the response
 
     if VERBOSE:
-        logging.info(LOG_RESPONSE.format(resp))  # Pass resp to format
+        logging.info(f"RESPONSE: {resp}")  # Log the response directly
+
 
 def generate_text_chunked(input_text, model, generation_parameters, max_tokens_to_generate):
     """Generates text in chunks to avoid token limit errors."""
@@ -387,17 +405,7 @@ def generate_text_chunked(input_text, model, generation_parameters, max_tokens_to_generate):
 
     return ''.join(generated_text)
 
-    formatted_prompt = format_prompt(prompt, history, max_history_turns=5)  # Truncated history
-    logging.info(f"Formatted Prompt: {formatted_prompt}")
-    stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
-    output = ""
-
-    for response in stream:
-        output += response.token.text
-        yield output
-    return output
-
-
+
 additional_inputs=[
     gr.Dropdown(
         label="Agents",
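
Note: the lines removed in this last hunk were unreachable, having been pasted after generate_text_chunked's return, and referenced names (prompt, history, generate_kwargs) that do not exist in that scope. The function's real body sits outside this diff; a sketch of the chunking pattern its docstring implies (an assumption, not the file's actual implementation):

from huggingface_hub import InferenceClient

def generate_text_chunked_sketch(input_text, model, generation_parameters,
                                 max_tokens_to_generate, chunk_size=512):
    # Assumed pattern: request up to chunk_size tokens per call, feed each
    # chunk back in as context, stop when the budget is spent or output dries up.
    client = InferenceClient(model)
    pieces, prompt, remaining = [], input_text, max_tokens_to_generate
    while remaining > 0:
        step = min(chunk_size, remaining)
        chunk = client.text_generation(prompt, max_new_tokens=step, **generation_parameters)
        if not chunk:
            break
        pieces.append(chunk)
        prompt += chunk
        remaining -= step
    return ''.join(pieces)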
 