WICKED4950 committed
Commit 4b69c94 · verified · 1 Parent(s): 0f89ae0

Update app.py

Files changed (1)
  1. app.py +21 -6
app.py CHANGED
@@ -2,19 +2,34 @@ import gradio as gr
 from transformers import AutoModelForCausalLM, AutoTokenizer
 print("Loading the model......")
 model_name = "WICKED4950/Irisonego5"
+strategy = tf.distribute.MirroredStrategy()
+tf.config.optimizer.set_jit(True)  # Enable XLA
 tokenizer = AutoTokenizer.from_pretrained(model_name)
-model = AutoModelForCausalLM.from_pretrained(model_name)
+with strategy.scope():
+    model = AutoModelForCausalLM.from_pretrained(model_name)
 
 print("Interface getting done....")
 # Define the chatbot function
-def chatbot(input_text):
-    inputs = tokenizer(input_text, return_tensors="pt")
-    outputs = model.generate(inputs["input_ids"], max_length=100)
-    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
+def predict(user_input):
+    # Tokenize input text
+    inputs = tokenizer(user_input, return_tensors="tf", padding=True, truncation=True)
+
+    # Generate the response using the model
+    response_ids = model.generate(
+        inputs['input_ids'],
+        max_length=128,  # Set max length of response
+        do_sample=True,  # Sampling for variability
+        top_k=15,  # Consider top 50 tokens
+        top_p=0.95,  # Nucleus sampling
+        temperature=0.8  # Adjusts creativity of response
+    )
+
+    # Decode the response
+    response = tokenizer.decode(response_id[0], skip_special_tokens=True)
     return response
 
 # Gradio interface
-iface = gr.Interface(fn=chatbot,
+iface = gr.Interface(fn=predict,
     inputs="text",
     outputs="text",
     title="Your Chatbot")
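
As committed, the new app.py will not run: `tf` is referenced but TensorFlow is never imported, the decode line uses `response_id` where `response_ids` was defined, the `top_k=15` argument is commented as "top 50 tokens", and `return_tensors="tf"` produces TensorFlow tensors while `AutoModelForCausalLM` loads a PyTorch model. Below is a minimal corrected sketch, not the author's exact file. It assumes the TensorFlow backend was intended, so it swaps in `TFAutoModelForCausalLM` (with `from_pt=True` as an assumption in case the repo only ships PyTorch weights), adds a pad-token fallback as a precaution, and ends with an `iface.launch()` call that the diff hunk does not show.

import gradio as gr
import tensorflow as tf  # missing in the committed version, but `tf` is used below
from transformers import AutoTokenizer, TFAutoModelForCausalLM

print("Loading the model......")
model_name = "WICKED4950/Irisonego5"

strategy = tf.distribute.MirroredStrategy()
tf.config.optimizer.set_jit(True)  # Enable XLA

tokenizer = AutoTokenizer.from_pretrained(model_name)
if tokenizer.pad_token is None:
    # Precaution: padding=True requires a pad token; reuse EOS if none is set
    tokenizer.pad_token = tokenizer.eos_token

with strategy.scope():
    # from_pt=True is an assumption; drop it if the repo already has TensorFlow weights
    model = TFAutoModelForCausalLM.from_pretrained(model_name, from_pt=True)

print("Interface getting done....")

# Define the chatbot function
def predict(user_input):
    # Tokenize input text as TensorFlow tensors
    inputs = tokenizer(user_input, return_tensors="tf", padding=True, truncation=True)

    # Generate the response with sampling
    response_ids = model.generate(
        inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_length=128,   # max length of the response
        do_sample=True,   # sampling for variability
        top_k=15,         # consider the top 15 tokens
        top_p=0.95,       # nucleus sampling
        temperature=0.8,  # adjusts creativity of the response
    )

    # Decode the response (the commit writes `response_id`, a NameError)
    response = tokenizer.decode(response_ids[0], skip_special_tokens=True)
    return response

# Gradio interface
iface = gr.Interface(fn=predict,
                     inputs="text",
                     outputs="text",
                     title="Your Chatbot")

iface.launch()  # assumed entry point; the hunk ends before any launch call

If the PyTorch backend is preferred instead, keep `AutoModelForCausalLM`, switch back to `return_tensors="pt"`, and drop the `MirroredStrategy` and XLA lines, which only affect TensorFlow execution.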