ndwdgda committed · verified
Commit 08ab954 · 1 Parent(s): 14f6a76

Update app.py

Files changed (1): app.py (+60, -34)
app.py CHANGED
@@ -1,44 +1,70 @@
- from transformers import pipeline
- from gradio import gr
- from huggingface_hub import InferenceClient

- pipe = pipeline("fill-mask", model="google-bert/bert-base-uncased")

- def respond(user_input, history, system_message, max_tokens, temperature, top_p):
-     messages = [{"role": "system", "content": system_message}]

-     for val in history:
-         if val["role"] == "user":
-             messages.append(val)
-         if val["role"] == "assistant":
-             messages.append(val)

      messages.append({"role": "user", "content": user_input})

-     response = ""
-     for message in client.chat_completion(
-         messages,
-         max_tokens=max_tokens,
-         stream=True,
          temperature=temperature,
          top_p=top_p,
-     ):
-         token = message.choices[0].delta.content
-         response += token
-         yield response
-
- client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
-
- demo = gr.ChatInterface(
-     respond,
-     inputs=[
-         gr.Textbox(placeholder="Type your message here", label="User input"),
-         gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
-         gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
-         gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
-         gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
-     ],
- )

  if __name__ == "__main__":
-     demo.launch()
+ import logging
+ import os
+ import torch
+ from transformers import AutoTokenizer, AutoModelForCausalLM

+ # Set up logging
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)

+ # Checkpoint paths
+ model_checkpoint_path = "model_checkpoint.pth"
+ tokenizer_checkpoint_path = "tokenizer_checkpoint.pth"

+ # Load model and tokenizer from checkpoint if they exist
+ if os.path.exists(model_checkpoint_path) and os.path.exists(tokenizer_checkpoint_path):
+     try:
+         model = torch.load(model_checkpoint_path)
+         tokenizer = torch.load(tokenizer_checkpoint_path)
+         logger.info("Model and tokenizer loaded from checkpoint.")
+     except Exception as e:
+         logger.error(f"Failed to load model or tokenizer from checkpoint: {e}")
+         raise
+ else:
+     # Load model directly
+     try:
+         tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b")
+         model = AutoModelForCausalLM.from_pretrained("google/gemma-2-9b")
+         logger.info("Model and tokenizer loaded successfully.")
+     except Exception as e:
+         logger.error(f"Failed to load model or tokenizer: {e}")
+         raise

+ def respond(user_input, history, system_message, max_tokens=20, temperature=0.9, top_p=0.9):
+     messages = [{"role": "system", "content": system_message}]
+     messages.extend(history)
      messages.append({"role": "user", "content": user_input})

+     # Convert messages to a single string
+     input_text = " ".join([msg["content"] for msg in messages])
+
+     # Tokenize the input text
+     inputs = tokenizer(input_text, return_tensors="pt")
+
+     # Generate attention mask
+     attention_mask = inputs["attention_mask"]
+
+     # Generate text using the model
+     outputs = model.generate(
+         inputs.input_ids,
+         attention_mask=attention_mask,
+         max_length=max_tokens,
          temperature=temperature,
          top_p=top_p,
+         pad_token_id=tokenizer.eos_token_id,
+         do_sample=True
+     )
+
+     # Decode the generated text
+     response = tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+     return response

  if __name__ == "__main__":
+     print("Welcome to the Chatbot!")
+     while True:
+         user_input = input("You: ")
+         system_message = "Chatbot: "
+         history = [{"role": "assistant", "content": "Hello, how can I assist you today?"}]
+         response = respond(user_input, history, system_message)
+         print(response)
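
The updated app.py only falls back to downloading google/gemma-2-9b when model_checkpoint.pth and tokenizer_checkpoint.pth are absent; when they exist, it restores both with torch.load, which only succeeds if the files were written by torch.save on the whole objects. A minimal sketch of a helper that produces such checkpoints, assuming the same google/gemma-2-9b model id used in app.py (the file name save_checkpoint.py is hypothetical and not part of this commit):

    # save_checkpoint.py: hypothetical helper, not part of this commit
    import torch
    from transformers import AutoTokenizer, AutoModelForCausalLM

    # Same model id that app.py falls back to when no checkpoint exists
    tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b")
    model = AutoModelForCausalLM.from_pretrained("google/gemma-2-9b")

    # Pickle the whole objects so the torch.load calls in app.py can restore them
    torch.save(model, "model_checkpoint.pth")
    torch.save(tokenizer, "tokenizer_checkpoint.pth")

Note that save_pretrained/from_pretrained directories are the more common Transformers idiom for this, and that recent PyTorch releases default torch.load to weights_only=True, which rejects whole-object pickles unless weights_only=False is passed explicitly; app.py as committed assumes the older whole-object loading behavior.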