Staticaliza committed on
Commit
b382e61
·
verified ·
1 Parent(s): 8a2b20c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -4
app.py CHANGED
@@ -3,7 +3,14 @@ import torch
3
  import threading
4
  import spaces
5
 
6
- from transformers import TextIteratorStreamer
 
 
 
 
 
 
 
7
 
8
  print("Is CUDA available?", torch.cuda.is_available())
9
 
@@ -20,6 +27,7 @@ class ModelWrapper:
20
  device_map={'': 'cuda:0'},
21
  trust_remote_code=True,
22
  )
 
23
 
24
  print("Model is on device:", next(self.model.parameters()).device)
25
 
@@ -48,18 +56,19 @@ class ModelWrapper:
48
  generated_text += new_text
49
  yield generated_text
50
 
 
51
  model_wrapper = ModelWrapper()
52
 
 
53
  interface = gr.Interface(
54
  fn=model_wrapper.generate,
55
  inputs=gr.Textbox(lines=5, label="Input Prompt"),
56
- outputs=gr.Textbox(label="Generated Text"),
57
  title="Mistral-Large-Instruct-2407 Text Completion",
58
  description="Enter a prompt and receive a text completion using the Mistral-Large-Instruct-2407 INT4 model.",
59
  allow_flagging='never',
60
  live=False,
61
- cache_examples=False,
62
- streaming=True
63
  )
64
 
65
  if __name__ == "__main__":
 
3
  import threading
4
  import spaces
5
 
6
+ from transformers import AutoTokenizer, TextIteratorStreamer
7
+ from auto_gptq import AutoGPTQForCausalLM
8
+
9
+ # Model identifier
10
+ model_id = "xmadai/Mistral-Large-Instruct-2407-xMADai-INT4"
11
+
12
+ # Load the tokenizer
13
+ tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False, trust_remote_code=True)
14
 
15
  print("Is CUDA available?", torch.cuda.is_available())
16
 
 
27
  device_map={'': 'cuda:0'},
28
  trust_remote_code=True,
29
  )
30
+ self.model.eval()
31
 
32
  print("Model is on device:", next(self.model.parameters()).device)
33
 
 
56
  generated_text += new_text
57
  yield generated_text
58
 
59
+ # Instantiate the model wrapper
60
  model_wrapper = ModelWrapper()
61
 
62
+ # Create the Gradio interface
63
  interface = gr.Interface(
64
  fn=model_wrapper.generate,
65
  inputs=gr.Textbox(lines=5, label="Input Prompt"),
66
+ outputs=gr.Textbox(label="Generated Text", lines=10, streaming=True),
67
  title="Mistral-Large-Instruct-2407 Text Completion",
68
  description="Enter a prompt and receive a text completion using the Mistral-Large-Instruct-2407 INT4 model.",
69
  allow_flagging='never',
70
  live=False,
71
+ cache_examples=False
 
72
  )
73
 
74
  if __name__ == "__main__":