AIdeaText committed
Commit 1aaec00 • 1 Parent(s): 9428cb3

Update app.py

Files changed (1)
  1. app.py +63 -29
app.py CHANGED
@@ -1,51 +1,92 @@
 import streamlit as st
-from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
+from transformers import AutoModelForCausalLM, AutoTokenizer
 import torch
+from huggingface_hub import login
+import os
+
+def init_huggingface():
+    """Initialize Hugging Face authentication either from secrets or user input"""
+    if 'HUGGING_FACE_TOKEN' not in st.session_state:
+        # First try to get from environment variable
+        token = os.getenv('HUGGINGFACE_TOKEN')
+
+        # If not in environment, check streamlit secrets
+        if not token and 'huggingface_token' in st.secrets:
+            token = st.secrets['huggingface_token']
+
+        # If still not found, ask user
+        if not token:
+            token = st.text_input('Enter your Hugging Face token:', type='password')
+            if not token:
+                st.warning('Please enter your Hugging Face token to proceed')
+                st.stop()
+
+        st.session_state['HUGGING_FACE_TOKEN'] = token
+
+    # Login to Hugging Face
+    login(st.session_state['HUGGING_FACE_TOKEN'])
+    return True
 
 class LlamaDemo:
     def __init__(self):
-        self.model_name = "meta-llama/Llama-2-70b-chat"
-        # Initialize in lazy loading fashion
-        self._pipe = None
+        self.model_name = "meta-llama/Llama-2-70b-chat-hf"
+        self._model = None
+        self._tokenizer = None
 
     @property
-    def pipe(self):
-        if self._pipe is None:
-            self._pipe = pipeline(
-                "text-generation",
-                model=self.model_name,
+    def model(self):
+        if self._model is None:
+            self._model = AutoModelForCausalLM.from_pretrained(
+                self.model_name,
                 torch_dtype=torch.float16,
                 device_map="auto",
+                trust_remote_code=True,
+                load_in_8bit=True  # To optimize memory
+            )
+        return self._model
+
+    @property
+    def tokenizer(self):
+        if self._tokenizer is None:
+            self._tokenizer = AutoTokenizer.from_pretrained(
+                self.model_name,
                 trust_remote_code=True
             )
-        return self._pipe
+        return self._tokenizer
 
-    def generate_response(self, prompt: str, max_length: int = 512) -> str:
+    def generate_response(self, prompt: str, max_new_tokens: int = 512) -> str:
         # Format prompt for Llama 2 chat
         formatted_prompt = f"[INST] {prompt} [/INST]"
 
-        # Generate response using pipeline
-        response = self.pipe(
-            formatted_prompt,
-            max_new_tokens=max_length,
-            num_return_sequences=1,
-            temperature=0.7,
-            do_sample=True,
-            top_p=0.9
-        )[0]['generated_text']
+        inputs = self.tokenizer(formatted_prompt, return_tensors="pt").to(self.model.device)
+
+        with torch.no_grad():
+            outputs = self.model.generate(
+                **inputs,
+                max_new_tokens=max_new_tokens,
+                num_return_sequences=1,
+                temperature=0.7,
+                do_sample=True,
+                top_p=0.9,
+                pad_token_id=self.tokenizer.eos_token_id
+            )
 
-        # Extract response after the instruction tag
+        response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
         return response.split("[/INST]")[-1].strip()
 
 def main():
     st.set_page_config(
-        page_title="Llama 2 Chat Demo",
+        page_title="Llama 2 Demo",
         page_icon="🦙",
         layout="wide"
     )
 
     st.title("🦙 Llama 2 Chat Demo")
 
+    # Initialize Hugging Face authentication
+    if init_huggingface():
+        st.success("Successfully authenticated with Hugging Face!")
+
     # Initialize model
     if 'llama' not in st.session_state:
         with st.spinner("Loading Llama 2... This might take a few minutes..."):
@@ -82,13 +123,6 @@ def main():
         st.error(f"Error: {str(e)}")
 
     with st.sidebar:
-        st.markdown("""
-        ### About
-        This demo uses Llama-2-70B-chat, a large language model from Meta.
-
-        The model runs with automatic device mapping and mixed precision for optimal performance.
-        """)
-
        if st.button("Clear Chat History"):
             st.session_state.chat_history = []
             st.experimental_rerun()
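
For anyone wanting to exercise the reworked class outside Streamlit, a minimal sketch follows; it is not part of the commit. It assumes app.py keeps its main() call behind an if __name__ == "__main__" guard so the module is importable, that accelerate and bitsandbytes are installed (device_map="auto" and load_in_8bit=True depend on them), and that the authenticated account has been granted access to the gated meta-llama repository. When running under Streamlit, the huggingface_token secret read by init_huggingface() would come from .streamlit/secrets.toml.

import os
from huggingface_hub import login

from app import LlamaDemo  # hypothetical import; assumes app.py is on the path

# Authenticate with the same environment variable init_huggingface() checks first.
login(os.environ["HUGGINGFACE_TOKEN"])

demo = LlamaDemo()  # cheap: the lazy properties defer all downloads

# First use of demo.model / demo.tokenizer triggers from_pretrained();
# subsequent calls reuse the cached objects on the same device map.
reply = demo.generate_response("What is the capital of France?", max_new_tokens=64)
print(reply)

Note that even with 8-bit weights, a 70B-parameter checkpoint needs on the order of 70 GB of GPU memory, so the spinner's "a few minutes" warning is realistic. The bare load_in_8bit= argument was later deprecated in transformers in favor of passing quantization_config=BitsAndBytesConfig(load_in_8bit=True), which may matter on newer library versions.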