JMAA00 commited on
Commit
2ddc1fd
·
1 Parent(s): cc4027d
Files changed (1) hide show
  1. app.py +12 -19
app.py CHANGED
@@ -1,21 +1,21 @@
1
  import os
2
  import torch
3
  import gradio as gr
 
4
  from transformers import (
5
  AutoTokenizer,
6
  AutoModelForCausalLM,
7
  TextIteratorStreamer,
8
  )
9
 
10
- # 1) Cargamos el tokenizer y el modelo de deepseek-ai/DeepSeek-R1-Distill-Llama-8B
11
  print("Cargando tokenizer...")
12
  tokenizer = AutoTokenizer.from_pretrained("deepseek-ai/DeepSeek-R1-Distill-Llama-8B")
13
 
14
  print("Cargando modelo (puede tardar varios minutos)...")
15
  model = AutoModelForCausalLM.from_pretrained(
16
  "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
17
- device_map="auto", # Para usar GPU si está disponible
18
- torch_dtype=torch.float16 # Usa float16 en GPU; en CPU, cambia a float32
19
  )
20
  model.eval()
21
 
@@ -28,14 +28,12 @@ def respond(
28
  top_p: float,
29
  ):
30
  """
31
- - system_message: Texto del rol "system"
32
- - history: Historial [(user_message, assistant_reply), ...]
33
- - message: Mensaje actual del usuario
34
- Genera una respuesta en streaming usando transformers.TextIteratorStreamer
 
35
  """
36
-
37
- # Construimos un prompt concatenando 'system_message', 'history' y el nuevo 'message'
38
- # Esto es un ejemplo de formateo sencillo. Ajusta según tu preferencia de estilo chat.
39
  prompt = f"[SYSTEM] {system_message}\n"
40
  for (usr, bot) in history:
41
  if usr:
@@ -44,14 +42,11 @@ def respond(
44
  prompt += f"[ASSISTANT] {bot}\n"
45
  prompt += f"[USER] {message}\n[ASSISTANT]"
46
 
47
- # Usamos TextIteratorStreamer para obtener tokens a medida que se generan
48
  streamer = TextIteratorStreamer(
49
  tokenizer=tokenizer,
50
  skip_special_tokens=True
51
  )
52
 
53
- # Preparamos argumentos para model.generate
54
- # (similar a pipeline pero de bajo nivel)
55
  inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
56
  generation_kwargs = dict(
57
  **inputs,
@@ -59,24 +54,22 @@ def respond(
59
  max_new_tokens=max_tokens,
60
  temperature=temperature,
61
  top_p=top_p,
62
- do_sample=True, # para permitir sampling
63
- # repetition_penalty=1.0, # ajusta si lo deseas
64
  )
65
 
66
- # Lanzamos la generación en un hilo
67
- generation_thread = torch.Thread(
68
  target=model.generate,
69
  kwargs=generation_kwargs
70
  )
71
  generation_thread.start()
72
 
73
- # Leemos tokens a medida que se generan y yield
74
  output_text = ""
75
  for new_token in streamer:
76
  output_text += new_token
77
  yield output_text
78
 
79
- # Interfaz con ChatInterface
80
  demo = gr.ChatInterface(
81
  fn=respond,
82
  additional_inputs=[
 
1
  import os
2
  import torch
3
  import gradio as gr
4
+ import threading
5
  from transformers import (
6
  AutoTokenizer,
7
  AutoModelForCausalLM,
8
  TextIteratorStreamer,
9
  )
10
 
 
11
  print("Cargando tokenizer...")
12
  tokenizer = AutoTokenizer.from_pretrained("deepseek-ai/DeepSeek-R1-Distill-Llama-8B")
13
 
14
  print("Cargando modelo (puede tardar varios minutos)...")
15
  model = AutoModelForCausalLM.from_pretrained(
16
  "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
17
+ device_map="auto", # Usa GPU si está disponible
18
+ torch_dtype=torch.float16 # FP16 en GPU; en CPU quizá float32
19
  )
20
  model.eval()
21
 
 
28
  top_p: float,
29
  ):
30
  """
31
+ Construimos el prompt a partir de:
32
+ - system_message
33
+ - history (lista de (user, assistant))
34
+ - message actual
35
+ Generamos tokens progresivamente con TextIteratorStreamer.
36
  """
 
 
 
37
  prompt = f"[SYSTEM] {system_message}\n"
38
  for (usr, bot) in history:
39
  if usr:
 
42
  prompt += f"[ASSISTANT] {bot}\n"
43
  prompt += f"[USER] {message}\n[ASSISTANT]"
44
 
 
45
  streamer = TextIteratorStreamer(
46
  tokenizer=tokenizer,
47
  skip_special_tokens=True
48
  )
49
 
 
 
50
  inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
51
  generation_kwargs = dict(
52
  **inputs,
 
54
  max_new_tokens=max_tokens,
55
  temperature=temperature,
56
  top_p=top_p,
57
+ do_sample=True,
 
58
  )
59
 
60
+ # Usamos threading.Thread en lugar de torch.Thread
61
+ generation_thread = threading.Thread(
62
  target=model.generate,
63
  kwargs=generation_kwargs
64
  )
65
  generation_thread.start()
66
 
67
+ # Leemos tokens a medida que se generan y los enviamos a Gradio (yield)
68
  output_text = ""
69
  for new_token in streamer:
70
  output_text += new_token
71
  yield output_text
72
 
 
73
  demo = gr.ChatInterface(
74
  fn=respond,
75
  additional_inputs=[