laragrl committed on
Commit 30f409d · verified · 1 Parent(s): 848ec16
Files changed (1): app.py (+4 -4)
app.py CHANGED
@@ -10,26 +10,26 @@ import torch
 # }
 
 model_names = {
-    "LeoLM_13B": "LeoLM/leo-hessianai-7b",
+    "LeoLM_7B": "LeoLM/leo-hessianai-7b",
     "Occiglot_7B": "occiglot/occiglot-7b-de-en"
 }
 
 # Prepare tokenizers and models
 tokenizers = {name: AutoTokenizer.from_pretrained(model) for name, model in model_names.items()}
-models = {name: AutoModelForCausalLM.from_pretrained(model, device_map="auto", torch_dtype=torch.float16) for name, model in model_names.items()}
+models = {name: AutoModelForCausalLM.from_pretrained(model) for name, model in model_names.items()}
 
 # Function to generate the response
 def generate_response(model_choice, prompt):
     tokenizer = tokenizers[model_choice]
     model = models[model_choice]
-    inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
+    inputs = tokenizer(prompt, return_tensors="pt")
     outputs = model.generate(inputs["input_ids"], max_new_tokens=100, do_sample=True)
     response = tokenizer.decode(outputs[0], skip_special_tokens=True)
     return response
 
 # Gradio Interface
 with gr.Blocks() as demo:
-    gr.Markdown("# Vergleich von LLMs: LeoLM, Occiglot und LLaMA 2")
+    gr.Markdown("# Vergleich von LLMs: LeoLM und Occiglot")
     with gr.Row():
         model_choice = gr.Radio(list(model_names.keys()), label="Modell auswählen")
         prompt = gr.Textbox(label="Frage stellen", placeholder="Was sind die Hauptursachen für Bluthochdruck?")
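
The net effect of this commit is to make the Space CPU-only: `device_map="auto"` and `torch_dtype=torch.float16` are dropped from model loading, and the tokenized inputs are no longer moved to `"cuda"`. If GPU use should stay optional rather than removed outright, a device-agnostic variant is possible. A minimal sketch, assuming the same `model_names` mapping and Transformers/PyTorch stack as `app.py` (the `device`/`dtype` selection logic is an addition for illustration, not part of this commit):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Assumption: use CUDA with fp16 when a GPU is present, otherwise CPU with fp32.
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

model_names = {
    "LeoLM_7B": "LeoLM/leo-hessianai-7b",
    "Occiglot_7B": "occiglot/occiglot-7b-de-en",
}

tokenizers = {name: AutoTokenizer.from_pretrained(repo) for name, repo in model_names.items()}
models = {
    name: AutoModelForCausalLM.from_pretrained(repo, torch_dtype=dtype).to(device)
    for name, repo in model_names.items()
}

def generate_response(model_choice, prompt):
    tokenizer = tokenizers[model_choice]
    model = models[model_choice]
    # Move inputs to wherever the model lives instead of hard-coding "cuda".
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(inputs["input_ids"], max_new_tokens=100, do_sample=True)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
```

Keying input placement off `model.device` lets the same code run on both CPU-only and GPU hardware, which the removed hard-coded `.to("cuda")` call prevented on CPU Spaces.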