import streamlit as st
import torch
from huggingface_hub import login
from transformers import AutoTokenizer, AutoModelForCausalLM
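
# Implied runtime dependencies (assumed, not pinned anywhere in this file):
# streamlit, torch, transformers, huggingface_hub, plus accelerate for
# device_map="auto" and bitsandbytes for load_in_8bit=True.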

# Check GPU availability at startup
def check_gpu():
    if torch.cuda.is_available():
        gpu_info = {
            "GPU Available": True,
            "GPU Name": torch.cuda.get_device_name(0),
            "Total Memory (GB)": round(torch.cuda.get_device_properties(0).total_memory / 1e9, 2),
            "CUDA Version": torch.version.cuda
        }
        return gpu_info
    return {"GPU Available": False}
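
# Note: torch.version.cuda is the CUDA version PyTorch was compiled against,
# which can differ from the version provided by the installed driver.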

# Configure Hugging Face authentication
def setup_auth():
    if 'HUGGINGFACE_TOKEN' in st.secrets:
        login(st.secrets['HUGGINGFACE_TOKEN'])
        return True
    else:
        st.error("Hugging Face token not found in secrets")
        st.stop()  # halts the script run; the return below is never reached
        return False
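
# Access note: Llama 2 is a gated model, so the token must belong to an account
# that has accepted Meta's license on the Hugging Face model page; otherwise
# from_pretrained() fails with an authorization error.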

class LlamaDemo:
    def __init__(self):
        self.model_name = "meta-llama/Llama-2-7b-chat-hf"
        self._model = None
        self._tokenizer = None
        
    @property
    def model(self):
        if self._model is None:
            self._model = AutoModelForCausalLM.from_pretrained(
                self.model_name,
                torch_dtype=torch.float16,  # float16 halves memory for non-quantized modules
                device_map="auto",
                load_in_8bit=True  # 8-bit quantization to cut memory use further
            )
        return self._model
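
    # Alternative sketch, assuming a recent transformers/bitsandbytes where the
    # bare load_in_8bit flag is deprecated in favor of an explicit config:
    #
    #   from transformers import BitsAndBytesConfig
    #   AutoModelForCausalLM.from_pretrained(
    #       self.model_name,
    #       device_map="auto",
    #       quantization_config=BitsAndBytesConfig(load_in_8bit=True),
    #   )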
    
    @property
    def tokenizer(self):
        if self._tokenizer is None:
            self._tokenizer = AutoTokenizer.from_pretrained(
                self.model_name,
                trust_remote_code=True
            )
        return self._tokenizer

    def generate_response(self, prompt: str, max_new_tokens: int = 512) -> str:
        formatted_prompt = f"[INST] {prompt} [/INST]"
        
        inputs = self.tokenizer(formatted_prompt, return_tensors="pt").to(self.model.device)
        
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                num_return_sequences=1,
                temperature=0.7,
                do_sample=True,
                top_p=0.9,
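                # Llama 2 defines no pad token, so EOS is reused for padding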
                pad_token_id=self.tokenizer.eos_token_id
            )
        
        # Free cached GPU memory after each generation
        torch.cuda.empty_cache()
        
        response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        return response.split("[/INST]")[-1].strip()
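
    # Prompt-format note: the full Llama 2 chat template also carries a system
    # prompt, roughly (single-turn sketch):
    #   f"<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n{user_prompt} [/INST]"
    # Multi-turn histories chain prior [INST] ... [/INST] answer pairs.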

def main():
    st.set_page_config(
        page_title="Llama 2 Chat Demo",
        page_icon="🦙",
        layout="wide"
    )

    # Log in before the gated Llama 2 weights are requested
    setup_auth()

    st.title("🦙 Llama 2 Chat Demo")
    
    # Show GPU information
    gpu_info = check_gpu()
    with st.expander("💻 GPU Info", expanded=False):
        for key, value in gpu_info.items():
            st.write(f"{key}: {value}")
    
    # Initialize the model wrapper (the weights themselves load lazily on first access)
    if 'llama' not in st.session_state:
        with st.spinner("Loading Llama 2... This might take a few minutes..."):
            st.session_state.llama = LlamaDemo()
    
    if 'chat_history' not in st.session_state:
        st.session_state.chat_history = []
        
    # Chat interface
    with st.container():
        for message in st.session_state.chat_history:
            with st.chat_message(message["role"]):
                st.write(message["content"])
    
        if prompt := st.chat_input("What would you like to discuss?"):
            st.session_state.chat_history.append({
                "role": "user",
                "content": prompt
            })
            
            with st.chat_message("user"):
                st.write(prompt)
            
            with st.chat_message("assistant"):
                with st.spinner("Thinking..."):
                    try:
                        response = st.session_state.llama.generate_response(prompt)
                        st.write(response)
                        st.session_state.chat_history.append({
                            "role": "assistant",
                            "content": response
                        })
                    except Exception as e:
                        st.error(f"Error: {str(e)}")
    
    with st.sidebar:
        st.markdown("""
        ### Memory Management
        To optimize GPU usage and costs:
        - Model runs in 8-bit precision
        - Memory is cleared after each generation
        - Space sleeps after inactivity
        """)
        
        if st.button("Clear Chat History"):
            st.session_state.chat_history = []
            st.rerun()  # st.experimental_rerun() was deprecated in favor of st.rerun()

if __name__ == "__main__":
    main()
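
# Run locally with (assumed filename):
#   streamlit run app.py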