Yahir committed
Commit a8ce6ea (verified)
Parent: a4eda25

Update app.py

Files changed (1): app.py +44 -45
app.py CHANGED
@@ -1,106 +1,105 @@
  from huggingface_hub import InferenceClient
  import gradio as gr

- client = InferenceClient(
      "google/gemma-7b-it"
  )

- def format_prompt(message, history):
      prompt = ""
-     if history:
-         # <start_of_turn>userWhat is recession?<end_of_turn><start_of_turn>model
-         for user_prompt, bot_response in history:
-             prompt += f"<start_of_turn>user{user_prompt}<end_of_turn>"
-             prompt += f"<start_of_turn>model{bot_response}"
-     prompt += f"<start_of_turn>user{message}<end_of_turn><start_of_turn>model"
      return prompt

- def generate(
-     prompt, history, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0,
  ):
-     if not history:
-         history = []
-         hist_len = 0
-     if history:
-         hist_len = len(history)
-         print(hist_len)

-     temperature = float(temperature)
-     if temperature < 1e-2:
-         temperature = 1e-2
      top_p = float(top_p)

-     generate_kwargs = dict(
-         temperature=temperature,
-         max_new_tokens=max_new_tokens,
          top_p=top_p,
-         repetition_penalty=repetition_penalty,
          do_sample=True,
          seed=42,
      )

-     formatted_prompt = format_prompt(prompt, history)

-     stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
-     output = ""

-     for response in stream:
-         output += response.token.text
-         yield output
-     return output


- additional_inputs = [
      gr.Slider(
-         label="Temperature",
          value=0.9,
          minimum=0.0,
          maximum=1.0,
          step=0.05,
          interactive=True,
-         info="Higher values produce more diverse outputs",
      ),
      gr.Slider(
-         label="Max new tokens",
          value=512,
          minimum=0,
          maximum=1048,
          step=64,
          interactive=True,
-         info="The maximum number of new tokens",
      ),
      gr.Slider(
-         label="Top-p (nucleus sampling)",
          value=0.90,
          minimum=0.0,
          maximum=1,
          step=0.05,
          interactive=True,
-         info="Higher values sample more low-probability tokens",
      ),
      gr.Slider(
-         label="Repetition penalty",
          value=1.2,
          minimum=1.0,
          maximum=2.0,
          step=0.05,
          interactive=True,
-         info="Penalize repeated tokens",
      )
  ]

- # Create a Chatbot object with the desired height
  chatbot = gr.Chatbot(height=450,
                       layout="bubble")

  with gr.Blocks() as demo:
      gr.HTML("<h1><center>🤖 Google-Gemma-7B-Chat 💬</center></h1>")
      gr.ChatInterface(
-         generate,
-         chatbot=chatbot,  # Use the created Chatbot object
-         additional_inputs=additional_inputs,
-         examples=[["What is the meaning of life?"], ["Tell me something about Mt Fuji."]],
      )

- demo.queue().launch(debug=True)
 
  from huggingface_hub import InferenceClient
  import gradio as gr

+ cliente = InferenceClient(
      "google/gemma-7b-it"
  )

+ def format_prompt(mensaje, historial):
      prompt = ""
+     if historial:
+         for usuario, respuesta_bot in historial:
+             prompt += f"<start_of_turn>user{usuario}<end_of_turn>"
+             prompt += f"<start_of_turn>model{respuesta_bot}"
+     prompt += f"<start_of_turn>user{mensaje}<end_of_turn><start_of_turn>model"
      return prompt

+ def generar(
+     mensaje, historial, temperatura=0.9, max_nuevos_tokens=256, top_p=0.95, penalizacion_repetición=1.0,
  ):
+     if not historial:
+         historial = []
+         longitud_hist = 0
+     if historial:
+         longitud_hist = len(historial)
+         print(longitud_hist)

+     temperatura = float(temperatura)
+     if temperatura < 1e-2:
+         temperatura = 1e-2
      top_p = float(top_p)

+     # text_generation only accepts huggingface_hub's English parameter
+     # names, so the Spanish variables are mapped onto those keys here.
+     generar_kwargs = dict(
+         temperature=temperatura,
+         max_new_tokens=max_nuevos_tokens,
          top_p=top_p,
+         repetition_penalty=penalizacion_repetición,
          do_sample=True,
          seed=42,
      )

+     prompt_formateado = format_prompt(mensaje, historial)

+     flujo = cliente.text_generation(prompt_formateado, **generar_kwargs, stream=True, details=True, return_full_text=False)
+     salida = ""

+     for respuesta in flujo:
+         salida += respuesta.token.text
+         yield salida
+     return salida


+ entradas_adicionales = [
      gr.Slider(
+         label="Temperature",
          value=0.9,
          minimum=0.0,
          maximum=1.0,
          step=0.05,
          interactive=True,
+         info="Higher values produce more diverse outputs",
      ),
      gr.Slider(
+         label="Max new tokens",
          value=512,
          minimum=0,
          maximum=1048,
          step=64,
          interactive=True,
+         info="The maximum number of new tokens",
      ),
      gr.Slider(
+         label="Top-p (nucleus sampling)",
          value=0.90,
          minimum=0.0,
          maximum=1,
          step=0.05,
          interactive=True,
+         info="Higher values sample more low-probability tokens",
      ),
      gr.Slider(
+         label="Repetition penalty",
          value=1.2,
          minimum=1.0,
          maximum=2.0,
          step=0.05,
          interactive=True,
+         info="Penalize repeated tokens",
      )
  ]

+ # Create a Chatbot object with the desired height
  chatbot = gr.Chatbot(height=450,
                       layout="bubble")

  with gr.Blocks() as demo:
      gr.HTML("<h1><center>🤖 Google-Gemma-7B-Chat 💬</center></h1>")
      gr.ChatInterface(
+         generar,
+         chatbot=chatbot,  # Use the created Chatbot object
+         additional_inputs=entradas_adicionales,
+         examples=[["What is the meaning of life?"], ["Tell me something about Mount Fuji."]],
      )

+ demo.queue().launch(debug=True)
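
For reference, a minimal sketch of the prompt string that format_prompt builds. The first turn reuses the example from the comment removed from the old code; the follow-up message and the assistant reply are hypothetical, invented for illustration. The <start_of_turn> markers mirror Gemma's instruction-tuned chat format:

historial = [("What is recession?", "A recession is a sustained decline in economic activity.")]
print(format_prompt("Tell me more.", historial))
# <start_of_turn>userWhat is recession?<end_of_turn><start_of_turn>modelA recession is a sustained decline in economic activity.<start_of_turn>userTell me more.<end_of_turn><start_of_turn>model

Because generar is a generator, gr.ChatInterface consumes each partial salida it yields, so the reply streams token by token into the chat window.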