vsrinivas committed on
Commit e181201 · verified · 1 Parent(s): d7c09eb

Update app.py

Files changed (1)
  1. app.py +33 -19
app.py CHANGED
@@ -4,24 +4,29 @@ import torch
 import gradio as gr
 
 desired_dtype = torch.bfloat16
+torch.set_default_dtype(torch.bfloat16)
 
-# checkpoint = "tiiuae/falcon-40b-instruct"
-checkpoint ="tiiuae/falcon-7b-instruct"
-# checkpoint = "tiiuae/falcon-7b"
-# checkpoint = "mistralai/Mixtral-8x7B-Instruct-v0.1"
-# checkpoint = "gpt2"
-# checkpoint = "amazon/FalconLite2"
+
+# checkpoint = "vsrinivas/falconlite2"
+checkpoint = "tiiuae/falcon-7b-instruct"
 
 model = AutoModelForCausalLM.from_pretrained(
-    checkpoint, device_map="auto",
-    offload_folder="off_load",
-    trust_remote_code=True,
-    # torch_dtype="auto",
-    )
-tokenizer = AutoTokenizer.from_pretrained(checkpoint,
-    trust_remote_code=True,
-    torch_dtype="auto",
-    )
+# checkpoint, device_map="auto", offload_folder="offload", trust_remote_code=True, torch_dtype="auto")
+    checkpoint, device_map="auto", offload_folder="offload", trust_remote_code=True)
+
+# tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True, torch_dtype="auto")
+tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)
+
+# model = AutoModelForCausalLM.from_pretrained(
+#     checkpoint, device_map="auto",
+#     # offload_folder="off_load",
+#     trust_remote_code=True,
+#     # torch_dtype="auto",
+#     )
+# tokenizer = AutoTokenizer.from_pretrained(checkpoint,
+#     trust_remote_code=True,
+#     torch_dtype="auto",
+#     )
 
 # model = "tiiuae/FalconLite2"
 # tokenizer = AutoTokenizer.from_pretrained(model,
@@ -32,14 +37,23 @@ pipeline = transformers.pipeline(
     "text-generation",
     model=model,
     tokenizer=tokenizer,
-    # use_safetensors=True,
-    # torch_dtype=torch.bfloat16,
+    torch_dtype=torch.bfloat16,
     trust_remote_code=True,
     device_map="auto",
-    offload_folder="off_load",
-    # offload_state_dict = True,
     )
 
+# pipeline = transformers.pipeline(
+#     "text-generation",
+#     model=model,
+#     tokenizer=tokenizer,
+#     # use_safetensors=True,
+#     # torch_dtype=torch.bfloat16,
+#     trust_remote_code=True,
+#     device_map="auto",
+#     offload_folder="off_load",
+#     # offload_state_dict = True,
+#     )
+
 # def format_chat_prompt(message, chat_history):
 #     prompt = ""
 #     for turn in chat_history:
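
For reference, a minimal sketch of the load-and-generate path that app.py ends up with after this commit. The checkpoint, loading arguments, and pipeline arguments mirror the added lines above; the example prompt and the generation parameters (max_new_tokens, do_sample) are illustrative assumptions, not values taken from the Space.

import torch
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer

# Default dtype switched to bfloat16, as added in this commit.
torch.set_default_dtype(torch.bfloat16)

checkpoint = "tiiuae/falcon-7b-instruct"

# Loading mirrors the new app.py: weights are spread across available devices,
# with offload into the "offload" folder when accelerator memory runs out.
model = AutoModelForCausalLM.from_pretrained(
    checkpoint, device_map="auto", offload_folder="offload", trust_remote_code=True
)
tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)

pipeline = transformers.pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
    device_map="auto",
)

# Illustrative call only; the real app builds its prompt from the Gradio chat history.
outputs = pipeline("What is the capital of France?", max_new_tokens=64, do_sample=False)
print(outputs[0]["generated_text"])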