PyaeSoneK committed on
Commit
e6f2506
·
1 Parent(s): 0be1a66

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -6
app.py CHANGED
@@ -57,7 +57,12 @@ st.set_page_config(
57
  page_icon = '🕵')
58
 
59
 
 
 
 
 
60
 
 
61
 
62
  @st.cache_resource
63
  def load_llm_model():
@@ -67,13 +72,39 @@ def load_llm_model():
67
  # "load_in_8bit": True,"max_length": 256, "temperature": 0,
68
  # "repetition_penalty": 1.5})
69
 
70
- token = st.secrets['hf_access_token']
71
- llm = AutoModelForCausalLM.from_pretrained('PyaeSoneK/LlamaV2LegalFineTuned',
72
- device_map='auto',
73
- torch_dtype=torch.float16,
74
- use_auth_token= st.secrets['hf_access_token'],
 
75
  )
76
- return llm
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
 
78
 
79
  @st.cache_resource
 
57
  page_icon = '🕵')
58
 
59
 
60
+ from transformers import AutoModel
61
+ import torch
62
+ import transformers
63
+ from transformers import AutoTokenizer, AutoModelForCausalLM
64
 
65
+ from transformers import pipeline
66
 
67
  @st.cache_resource
68
  def load_llm_model():
 
72
  # "load_in_8bit": True,"max_length": 256, "temperature": 0,
73
  # "repetition_penalty": 1.5})
74
 
75
+ #token = st.secrets['hf_access_token']
76
+ #llm = AutoModelForCausalLM.from_pretrained(model_id = 'PyaeSoneK/LlamaV2LegalFineTuned',
77
+ # task = 'text2text-generation',
78
+ # device_map='auto',
79
+ # torch_dtype=torch.float16,
80
+ # use_auth_token= st.secrets['hf_access_token'],
81
  )
82
+ #return llm
83
+ pipe = pipeline("text-generation",
84
+ model=model,
85
+ tokenizer= tokenizer,
86
+ torch_dtype=torch.bfloat16,
87
+ device_map="auto",
88
+ max_new_tokens = 512,
89
+ do_sample=True,
90
+ top_k=30,
91
+ num_return_sequences=1,
92
+ eos_token_id=tokenizer.eos_token_id
93
+ )
94
+
95
+
96
+
97
+ llm = AutoModelForCausalLM.from_pretrained("PyaeSoneK/LlamaV2LegalFineTuned",
98
+ device_map='auto',
99
+ torch_dtype=torch.float16,
100
+ use_auth_token= st.secrets['hf_access_token'],)
101
+ # load_in_4bit=True
102
+
103
+ tokenizer = AutoTokenizer.from_pretrained("PyaeSoneK/LlamaV2LegalFineTuned",
104
+ use_auth_token=True,)
105
+
106
+ return llm
107
+
108
 
109
 
110
  @st.cache_resource