khoatran94 committed
Commit b2a9bc3 · 1 Parent(s): 725770a
Files changed (1):
  app.py (+10 -12)
app.py CHANGED
@@ -7,7 +7,7 @@ import torch
 import gradio as gr
 from prepare import prepare
 
-from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, pipeline
+from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
 from langchain_community.llms import HuggingFacePipeline
 from langchain.prompts import PromptTemplate, SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate
 from langchain_core.output_parsers import StrOutputParser
@@ -51,17 +51,15 @@ def read_pdf(file_path):
 @spaces.GPU
 def query_huggingface(text):
     print(zero.device)
-    pipe = pipeline(
-        "text-generation",
-        model="google/gemma-2-9b-it",
-        model_kwargs={"torch_dtype": torch.bfloat16},
-        device="cuda",  # replace with "mps" to run on a Mac device
-    )
-    messages = [
-        {"role": "user", "content": text},
-    ]
-    outputs = pipe(messages, max_new_tokens=256)
-    return outputs[0]["generated_text"][-1]["content"].strip()
+    tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b-it")
+    model = AutoModelForCausalLM.from_pretrained(
+        "google/gemma-2-9b-it",
+        device_map="auto",
+        torch_dtype=torch.bfloat16,
+    )
+    input_ids = tokenizer(text, return_tensors="pt").to("cuda")
+    outputs = model.generate(**input_ids, max_new_tokens=32)
+    return tokenizer.decode(outputs[0])
 
 # Gradio Interface for PDF Processing
 def process_file(file, query):
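
A note on the new generation path: the removed pipeline() call wrapped the text in a chat message and stripped the prompt from the output, while the added code tokenizes raw text and decodes the full sequence, so the return value will include Gemma's special tokens and the echoed prompt. Below is a minimal sketch, not part of the commit, of how the direct generate() call could keep the chat formatting and return only the completion; apply_chat_template is the standard transformers helper for this, and the 256-token budget mirrors the removed pipeline call rather than the committed max_new_tokens=32.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

def query_huggingface(text):
    # Same model/tokenizer setup as the committed version
    tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b-it")
    model = AutoModelForCausalLM.from_pretrained(
        "google/gemma-2-9b-it",
        device_map="auto",
        torch_dtype=torch.bfloat16,
    )
    # Same message format the removed pipeline() call used
    messages = [{"role": "user", "content": text}]
    # apply_chat_template wraps the text in Gemma's expected turn markers
    input_ids = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, return_tensors="pt"
    ).to(model.device)
    outputs = model.generate(input_ids, max_new_tokens=256)
    # Decode only the newly generated tokens, not the echoed prompt
    return tokenizer.decode(outputs[0][input_ids.shape[-1]:], skip_special_tokens=True)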