BenkHel committed on
Commit 44d690a · verified · 1 Parent(s): f236c5b

Update app.py

Files changed (1)
  1. app.py +22 -4
app.py CHANGED
@@ -6,6 +6,8 @@ import subprocess
 import spaces
 import cumo.serve.gradio_web_server as gws
 
+from transformers import AutoProcessor, LlavaMistralForCausalLM
+
 import datetime
 import json
 
@@ -40,12 +42,28 @@ disable_btn = gr.Button(interactive=False)
 
 device = "cuda" if torch.cuda.is_available() else "cpu"
 model_path = 'BenkHel/CumoThesis'
-model_base = 'mistralai/Mistral-7B-Instruct-v0.2'
-model_name = 'CumoThesis'
-conv_mode = 'mistral_instruct_system'
+conv_mode = 'mistral_instruct_system'  # still needed for the conversation templates
 load_8bit = False
 load_4bit = False
-tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, model_base, model_name, load_8bit, load_4bit, device=device, use_flash_attn=False)
+
+# Load the processor, which combines the tokenizer and the image processor
+processor = AutoProcessor.from_pretrained(model_path)
+
+# Load the model with the correct class
+model = LlavaMistralForCausalLM.from_pretrained(
+    model_path,
+    torch_dtype=torch.bfloat16,   # config.json specifies bfloat16
+    low_cpu_mem_usage=True,       # recommended for large models
+    load_in_4bit=load_4bit,
+    load_in_8bit=load_8bit,
+)
+
+# Assign the components to the old variable names so the rest of the code keeps working
+tokenizer = processor.tokenizer
+image_processor = processor.image_processor
+
+# Set the context length (in case the rest of the code needs it)
+context_len = model.config.max_position_embeddings
 model.config.training = False
 
 def upvote_last_response(state):
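
For reference, a minimal sketch of how the objects this commit creates (processor, model, tokenizer, device from the app.py snippet above) could serve a single request. The prompt string, the <image> placeholder handling, and the generation settings are illustrative assumptions: the real app builds its prompts through the 'mistral_instruct_system' conversation template in cumo.serve.gradio_web_server, and the generate() interface of the CuMo fork's LlavaMistralForCausalLM may differ from the standard transformers LLaVA pattern followed here.

# Minimal usage sketch (assumptions: prompt format, <image> placeholder,
# and the stock transformers LLaVA-style generate() interface).
import torch
from PIL import Image

image = Image.open("example.jpg")  # hypothetical input image
prompt = "[INST] <image>\nWhat material is this object made of? [/INST]"

# AutoProcessor bundles the tokenizer and image processor, so a single call
# yields input_ids, attention_mask, and pixel_values together.
inputs = processor(text=prompt, images=image, return_tensors="pt").to(device)

with torch.inference_mode():
    output_ids = model.generate(**inputs, do_sample=False, max_new_tokens=128)

print(tokenizer.decode(output_ids[0], skip_special_tokens=True))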