codeblacks commited on
Commit
18feec4
·
verified ·
1 Parent(s): a7357eb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -24
app.py CHANGED
@@ -1,36 +1,22 @@
1
  from sentence_transformers import SentenceTransformer
2
  import gradio as gr
3
  import torch
4
- import numpy as np
5
 
6
  # Load the pre-trained model
7
  embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
8
 
9
- # Define the function to process requests
10
- def generate_embeddings(text):
11
- # Split the input text into chunks (if needed)
12
- chunks = text.split('\n') # Assuming chunks are separated by new lines
13
-
14
- # Encode the input chunks to get the embeddings
15
- embeddings = embedding_model.encode(chunks, convert_to_tensor=False)
16
-
17
- # Convert the embeddings to a PyTorch tensor
18
- embeddings_tensor = torch.tensor(embeddings)
19
-
20
- # Add batch dimension to the tensor (if needed)
21
- embeddings_tensor = embeddings_tensor.unsqueeze(0) # Uncomment if a batch dimension is required
22
-
23
- # Return the embeddings tensor and its shape
24
- return embeddings_tensor.tolist(), embeddings_tensor.shape
25
 
26
  # Define the Gradio interface
27
  interface = gr.Interface(
28
- fn=generate_embeddings,
29
- inputs=gr.Textbox(lines=5, placeholder="Enter text chunks here..."),
30
- outputs=[gr.JSON(label="Embeddings"), gr.Label(label="Shape")],
31
- title="Sentence Transformer Embeddings",
32
- description="Generate embeddings for input text chunks."
33
  )
34
 
35
- # Launch the Gradio app
36
- interface.launch()
 
1
  from sentence_transformers import SentenceTransformer
2
  import gradio as gr
3
  import torch
 
4
 
5
  # Load the pre-trained model
6
  embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
7
 
8
+ def get_embeddings(sentences):
9
+ embeddings = model.encode(sentences, convert_to_tensor=True)
10
+ return embeddings.tolist()
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
  # Define the Gradio interface
13
  interface = gr.Interface(
14
+ fn=get_embeddings, # Function to call
15
+ inputs=gr.Textbox(lines=2, placeholder="Enter sentences here, one per line"), # Input component
16
+ outputs=gr.Image(label="Embeddings", image_formatter=plot_embeddings)
17
+ title="Sentence Embeddings", # Interface title
18
+ description="Enter sentences to get their embeddings." # Description
19
  )
20
 
21
+ # Launch the interface
22
+ interface.launch()