Abhijit-192-168-1-1 commited on
Commit
c51e482
·
1 Parent(s): a05b3ab

added app.py

Browse files
Files changed (1) hide show
  1. app.py +48 -0
app.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from llm2vec import LLM2Vec
3
+ from transformers import AutoTokenizer, AutoModel, AutoConfig
4
+ from peft import PeftModel
5
+ import torch
6
+ import os
7
+
8
+ torch.backends.cuda.enable_mem_efficient_sdp(False)
9
+ torch.backends.cuda.enable_flash_sdp(False)
10
+
11
+ # Read tokens from environment variables
12
+ GROQ_API_KEY = os.getenv('GROQ_API_KEY')
13
+ HF_TOKEN = os.getenv('HF_TOKEN')
14
+
15
+ if not GROQ_API_KEY or not HF_TOKEN:
16
+ raise ValueError("GROQ_API_KEY and HF_TOKEN must be set as environment variables.")
17
+
18
+ os.environ['GROQ_API_KEY'] = GROQ_API_KEY
19
+ os.environ['HF_TOKEN'] = HF_TOKEN
20
+
21
+
22
+ # Load tokenizer and model
23
+ tokenizer = AutoTokenizer.from_pretrained("McGill-NLP/LLM2Vec-Sheared-LLaMA-mntp")
24
+ config = AutoConfig.from_pretrained("McGill-NLP/LLM2Vec-Sheared-LLaMA-mntp", trust_remote_code=True)
25
+ model = AutoModel.from_pretrained("McGill-NLP/LLM2Vec-Sheared-LLaMA-mntp", trust_remote_code=True, config=config, torch_dtype=torch.bfloat16, device_map="cuda" if torch.cuda.is_available() else "cpu")
26
+
27
+ model = PeftModel.from_pretrained(model, "McGill-NLP/LLM2Vec-Sheared-LLaMA-mntp")
28
+ model = model.merge_and_unload()
29
+
30
+ # Load unsupervised SimCSE model
31
+ model = PeftModel.from_pretrained(model, "McGill-NLP/LLM2Vec-Sheared-LLaMA-mntp-unsup-simcse")
32
+
33
+ # Wrapper for encoding and pooling operations
34
+ l2v = LLM2Vec(model, tokenizer, pooling_mode="mean", max_length=512)
35
+
36
+ def encode_text(input_text):
37
+ encoding = l2v.encode(input_text)
38
+ return encoding
39
+
40
+ # Define Gradio interface
41
+ iface = gr.Interface(
42
+ fn=encode_text,
43
+ inputs=gr.inputs.Textbox(lines=2, placeholder="Enter text here..."),
44
+ outputs=gr.outputs.JSON()
45
+ )
46
+
47
+ # Launch Gradio app
48
+ iface.launch(share=True)