Yahir committed on
Commit dbe7aa2 (verified)
1 Parent(s): a8ce6ea

Update app.py

Files changed (1)
app.py +45 -44
app.py CHANGED
@@ -1,105 +1,106 @@
 from huggingface_hub import InferenceClient
 import gradio as gr
 
-cliente = InferenceClient(
+client = InferenceClient(
     "google/gemma-7b-it"
 )
 
-def format_prompt(mensaje, historial):
+def format_prompt(message, history):
     prompt = ""
-    if historial:
-        for usuario, respuesta_bot in historial:
-            prompt += f"<start_of_turn>user{usuario}<end_of_turn>"
-            prompt += f"<start_of_turn>model{respuesta_bot}"
-    prompt += f"<start_of_turn>user{mensaje}<end_of_turn><start_of_turn>model"
+    if history:
+        #<start_of_turn>userWhat is recession?<end_of_turn><start_of_turn>model
+        for user_prompt, bot_response in history:
+            prompt += f"<start_of_turn>user{user_prompt}<end_of_turn>"
+            prompt += f"<start_of_turn>model{bot_response}"
+    prompt += f"<start_of_turn>user{message}<end_of_turn><start_of_turn>model"
     return prompt
 
-def generar(
-    mensaje, historial, temperatura=0.9, max_nuevos_tokens=256, top_p=0.95, penalizacion_repetición=1.0,
+def generate(
+    prompt, history, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0,
 ):
-    if not historial:
-        historial = []
-    longitud_hist=0
-    if historial:
-        longitud_hist=len(historial)
-        print(longitud_hist)
+    if not history:
+        history = []
+    hist_len=0
+    if history:
+        hist_len=len(history)
+        print(hist_len)
 
-    temperatura = float(temperatura)
-    if temperatura < 1e-2:
-        temperatura = 1e-2
+    temperature = float(temperature)
+    if temperature < 1e-2:
+        temperature = 1e-2
     top_p = float(top_p)
 
-    generar_kwargs = dict(
-        temperatura=temperatura,
-        max_nuevos_tokens=max_nuevos_tokens,
+    generate_kwargs = dict(
+        temperature=temperature,
+        max_new_tokens=max_new_tokens,
         top_p=top_p,
-        penalizacion_repetición=penalizacion_repetición,
+        repetition_penalty=repetition_penalty,
         do_sample=True,
         seed=42,
     )
 
-    prompt_formateado = format_prompt(mensaje, historial)
+    formatted_prompt = format_prompt(prompt, history)
 
-    flujo = cliente.text_generation(prompt_formateado, **generar_kwargs, flujo=True, detalles=True, return_full_text=False)
-    salida = ""
+    stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
+    output = ""
 
-    for respuesta in flujo:
-        salida += respuesta.token.text
-        yield salida
-    return salida
+    for response in stream:
+        output += response.token.text
+        yield output
+    return output
 
 
-entradas_adicionales=[
+additional_inputs=[
     gr.Slider(
-        label="Temperatura",
+        label="Temperature",
         value=0.9,
         minimum=0.0,
         maximum=1.0,
         step=0.05,
         interactive=True,
-        info="Valores más altos producen salidas más diversas",
+        info="Higher values produce more diverse outputs",
     ),
     gr.Slider(
-        label="Máx. tokens nuevos",
+        label="Max new tokens",
         value=512,
         minimum=0,
         maximum=1048,
         step=64,
         interactive=True,
-        info="El máximo de nuevos tokens",
+        info="The maximum number of new tokens",
     ),
     gr.Slider(
-        label="Top-p (muestreo de núcleo)",
+        label="Top-p (nucleus sampling)",
         value=0.90,
         minimum=0.0,
         maximum=1,
         step=0.05,
         interactive=True,
-        info="Valores más altos muestrean más tokens de baja probabilidad",
+        info="Higher values sample more low-probability tokens",
     ),
     gr.Slider(
-        label="Penalización de repetición",
+        label="Repetition penalty",
         value=1.2,
         minimum=1.0,
         maximum=2.0,
         step=0.05,
         interactive=True,
-        info="Penaliza los tokens repetidos",
+        info="Penalize repeated tokens",
     )
 ]
 
-# Crea un objeto Chatbot con la altura deseada
+# Create a Chatbot object with the desired height
 chatbot = gr.Chatbot(height=450,
                      layout="bubble")
 
 with gr.Blocks() as demo:
     gr.HTML("<h1><center>🤖 Google-Gemma-7B-Chat 💬<h1><center>")
     gr.ChatInterface(
-        generar,
-        chatbot=chatbot, # Utiliza el objeto Chatbot creado
-        additional_inputs=entradas_adicionales,
-        examples=[["¿Cuál es el significado de la vida?"], ["Cuéntame algo sobre el Monte Fuji."]],
+        generate,
+        chatbot=chatbot, # Use the created Chatbot object
+        additional_inputs=additional_inputs,
+        examples=[["What is the meaning of life?"], ["Tell me something about Mt Fuji."]],
 
     )
 
-demo.queue().launch(debug=True)
+demo.queue().launch(debug=True)
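
Note: to see what the committed format_prompt actually emits, here is a minimal standalone sketch; the function body is copied from the new app.py above, while the history and message strings are made-up examples. For context, Gemma's reference chat template also places a newline after each <start_of_turn>user / <start_of_turn>model marker and closes model turns with <end_of_turn>; the builder committed here omits those separators.

# Standalone sketch of the committed prompt builder (copied from app.py above).
# The history and message strings are hypothetical examples.
def format_prompt(message, history):
    prompt = ""
    if history:
        for user_prompt, bot_response in history:
            prompt += f"<start_of_turn>user{user_prompt}<end_of_turn>"
            prompt += f"<start_of_turn>model{bot_response}"
    prompt += f"<start_of_turn>user{message}<end_of_turn><start_of_turn>model"
    return prompt

history = [("What is recession?", "A period of economic decline.")]
print(format_prompt("Tell me something about Mt Fuji.", history))
# <start_of_turn>userWhat is recession?<end_of_turn><start_of_turn>modelA period of economic decline.<start_of_turn>userTell me something about Mt Fuji.<end_of_turn><start_of_turn>model

The streaming loop in generate relies on InferenceClient.text_generation with stream=True and details=True, which yields token objects whose text is accumulated and re-yielded so Gradio can render partial output. A minimal sketch of exercising that same path outside Gradio (assumes network access and, since Gemma is a gated model, a Hugging Face token available to the client; the prompt is a made-up single turn):

from huggingface_hub import InferenceClient

client = InferenceClient("google/gemma-7b-it")
# stream=True with details=True yields token objects, as in generate() above.
for chunk in client.text_generation(
    "<start_of_turn>userHi!<end_of_turn><start_of_turn>model",
    max_new_tokens=32, stream=True, details=True, return_full_text=False,
):
    print(chunk.token.text, end="", flush=True)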