AIdeaText committed on
Commit
c3590b2
·
verified ·
1 Parent(s): 3cd2dad

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -8
app.py CHANGED
@@ -60,14 +60,16 @@ class Llama3Demo:
60
 
61
 
62
  ##################################################################
63
- def generate_response(self, prompt: str, max_new_tokens: int = 512) -> str:
64
- formatted_prompt = f"""<|system|>You are a helpful AI assistant.</s>
65
- <|user|>{prompt}</s>
66
- <|assistant|>"""
 
 
 
67
 
68
  inputs = self.tokenizer(formatted_prompt, return_tensors="pt").to(self.model.device)
69
 
70
- # Asegurar que tenemos un pad_token_id válido
71
  if self.tokenizer.pad_token_id is None:
72
  self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
73
 
@@ -76,10 +78,12 @@ class Llama3Demo:
76
  **inputs,
77
  max_new_tokens=max_new_tokens,
78
  num_return_sequences=1,
79
- temperature=0.7,
80
  do_sample=True,
81
- top_p=0.9,
82
- pad_token_id=self.tokenizer.pad_token_id # Explícitamente establecer pad_token_id
 
 
83
  )
84
 
85
  torch.cuda.empty_cache()
 
60
 
61
 
62
  ##################################################################
63
+ def generate_response(self, prompt: str, max_new_tokens: int = 512, temperature: float = 0.6,
64
+ top_p: float = 0.85, repetition_penalty: float = 1.2, top_k: int = 50) -> str:
65
+ formatted_prompt = f"""<|system|>You are a helpful AI assistant. Always provide accurate,
66
+ detailed, and well-reasoned responses. If you're unsure about something, acknowledge the uncertainty.
67
+ Break down complex topics into clear explanations.</s>
68
+ <|user|>{prompt}</s>
69
+ <|assistant|>"""
70
 
71
  inputs = self.tokenizer(formatted_prompt, return_tensors="pt").to(self.model.device)
72
 
 
73
  if self.tokenizer.pad_token_id is None:
74
  self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
75
 
 
78
  **inputs,
79
  max_new_tokens=max_new_tokens,
80
  num_return_sequences=1,
81
+ temperature=temperature,
82
  do_sample=True,
83
+ top_p=top_p,
84
+ top_k=top_k,
85
+ repetition_penalty=repetition_penalty,
86
+ pad_token_id=self.tokenizer.pad_token_id
87
  )
88
 
89
  torch.cuda.empty_cache()