ans123 committed
Commit 03620de · verified · 1 Parent(s): 38cd508

Update app.py

Files changed (1): app.py +13 -12
app.py CHANGED
@@ -1,12 +1,16 @@
 import gradio as gr
 import pandas as pd
 import torch
-from transformers import AutoModelForCausalLM, AutoTokenizer
-
-# Load the model and tokenizer
-model_name = "nvidia/Llama-3.1-Nemotron-70B-Instruct-HF"
-model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, device_map="auto")
-tokenizer = AutoTokenizer.from_pretrained(model_name)
+from transformers import pipeline
+
+# Load the model pipeline
+model_id = "meta-llama/Llama-3.2-1B"
+pipe = pipeline(
+    "text-generation",
+    model=model_id,
+    torch_dtype=torch.bfloat16,
+    device_map="auto"
+)
 
 # Define the system message for the model
 system_message = (
@@ -52,13 +56,10 @@ def chat(user_input, messages):
     # Prepare the input for the model
     input_text = system_message + "\n" + "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
 
-    # Tokenize and encode the input text
-    inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
-
     try:
-        # Generate a response from the model
-        outputs = model.generate(**inputs, max_length=150, num_return_sequences=1, temperature=0.7)
-        response_content = tokenizer.decode(outputs[0], skip_special_tokens=True)
+        # Generate a response using the pipeline
+        response = pipe(input_text, max_length=150, num_return_sequences=1, temperature=0.7)
+        response_content = response[0]['generated_text'].split('\n')[-1].strip()  # Extract the last line of the generated text
 
     except Exception as e:
         response_content = f"Error: {str(e)}"