dsfdfsghgf commited on
Commit
4a72a49
·
verified ·
1 Parent(s): 2042c5e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -7
app.py CHANGED
@@ -1,15 +1,19 @@
1
  import torch
2
  from transformers import AutoModelForCausalLM, AutoTokenizer
3
 
4
- # Modell und Tokenizer von Hugging Face laden
5
- model_name = "Qwen/Qwen2.5-Math-7B-Instruct"
 
 
6
  device = "cuda" if torch.cuda.is_available() else "cpu"
7
 
8
  # Modell und Tokenizer laden
9
  model = AutoModelForCausalLM.from_pretrained(
10
  model_name,
11
  device_map="auto", # Modell auf verfügbare Geräte verteilen
12
- trust_remote_code=True
 
 
13
  ).to(device).eval()
14
 
15
  tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
@@ -28,8 +32,7 @@ input_ids = tokenizer.encode(conversation_str, return_tensors="pt", add_special_
28
 
29
  # Inferenz durchführen
30
  with torch.no_grad():
31
- outputs = model.generate(input_ids=input_ids, max_length=512, num_return_sequences=1)
32
 
33
- # Ausgabe dekodieren und anzeigen
34
- response = tokenizer.decode(outputs[0], skip_special_tokens=True)
35
- print(response)
 
1
  import torch
2
  from transformers import AutoModelForCausalLM, AutoTokenizer
3
 
4
+ # Modellname für die kleinere Variante
5
+ model_name = "Qwen/Qwen2.5-Math-1.5B-Instruct"
6
+
7
+ # Überprüfen, ob eine GPU verfügbar ist
8
  device = "cuda" if torch.cuda.is_available() else "cpu"
9
 
10
  # Modell und Tokenizer laden
11
  model = AutoModelForCausalLM.from_pretrained(
12
  model_name,
13
  device_map="auto", # Modell auf verfügbare Geräte verteilen
14
+ low_cpu_mem_usage=True, # Versucht, den Speicherverbrauch zu reduzieren
15
+ trust_remote_code=True,
16
+ torch_dtype=torch.float16 # Reduziert den Speicherverbrauch
17
  ).to(device).eval()
18
 
19
  tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
 
32
 
33
  # Inferenz durchführen
34
  with torch.no_grad():
35
+ outputs = model.generate(input_ids=input_ids, max_length=256, num_return_sequences=1)
36
 
37
+ # Ausgabe anzeigen
38
+ print(tokenizer.decode(outputs[0], skip_special_tokens=True))