ak2603 committed
Commit ad882b1 · 1 Parent(s): f1b7b9d

debuggin llama

Files changed (1)
llama.py +13 -13
llama.py CHANGED
@@ -22,28 +22,28 @@ following this structure: 'Der Kunde ... und erwartet ...'. The summaries need t
 def load_llama_model():
     """Load Llama model and tokenizer with optimized settings"""
     tokenizer = AutoTokenizer.from_pretrained("Walid777/llama3-8b-emails-summarization")
-    model = AutoModelForCausalLM.from_pretrained(
-        "Walid777/llama3-8b-emails-summarization",
-        device_map="auto",
-        torch_dtype="auto"
-    )
+    model = AutoModelForCausalLM.from_pretrained("Walid777/llama3-8b-emails-summarization")
     return model, tokenizer
 
 def generate_llama_summary(email, model, tokenizer, prompt_template):
     """Generate summary using structured prompt template"""
-    formatted_prompt = prompt_template.format(email)
+    formatted_prompt = prompt_template.format(email, "")
 
     inputs = tokenizer(
-        formatted_prompt,
+        [formatted_prompt],
         return_tensors="pt"
-    ).to(model.device)
+    ).to("cuda")
 
     outputs = model.generate(
         **inputs,
-        max_new_tokens=128,
-        temperature=0.7,
-        pad_token_id=tokenizer.eos_token_id
+        max_new_tokens=128
     )
 
-    full_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
-    return full_text.split("### Summary:")[-1].strip()
+    summary = tokenizer.decode(outputs[0], skip_special_tokens=True)
+    if "### Summary:" in summary:
+        summary = summary.split("### Summary:")[-1].strip()
+    else:
+        summary = "Error: Could not extract summary"
+    return summary
+
+
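For orientation, the sketch below shows one way the two functions in this diff could be wired together after the change. It is a minimal, hypothetical example: the real prompt template is defined elsewhere in the repo and does not appear in the diff. The only constraints visible here are that the template takes two positional .format() slots (the email text plus an empty summary stub, per prompt_template.format(email, "")) and contains the literal "### Summary:" marker that the post-processing splits on. Since the updated code sends inputs to "cuda" unconditionally, the sketch also moves the model to the GPU and assumes one is available.

# Hypothetical usage sketch, assuming llama.py exposes the two functions shown above.
from llama import load_llama_model, generate_llama_summary

# Placeholder template (assumption, not the repository's actual prompt):
# two positional slots and the "### Summary:" marker that
# generate_llama_summary() splits the decoded output on.
PROMPT_TEMPLATE = (
    "Summarize the following customer email in German, "
    "following this structure: 'Der Kunde ... und erwartet ...'.\n\n"
    "### Email:\n{}\n\n"
    "### Summary:\n{}"
)

if __name__ == "__main__":
    # Downloads Walid777/llama3-8b-emails-summarization on first use.
    model, tokenizer = load_llama_model()
    # The new generate step moves inputs to "cuda", so the model is moved there too.
    model = model.to("cuda")
    email = "Sehr geehrte Damen und Herren, meine Bestellung ist noch nicht angekommen ..."
    # Formats the prompt, generates up to 128 new tokens, and returns the text
    # that follows "### Summary:".
    print(generate_llama_summary(email, model, tokenizer, PROMPT_TEMPLATE))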