johann22 committed
Commit 640f959 · verified · 1 Parent(s): f63f033

Create app.py

Files changed (1)
  1. app.py +121 -0
app.py ADDED
@@ -0,0 +1,121 @@
+ import gradio as gr
+ import numpy as np
+ import logging
+ import threading
+ from sentence_transformers import SentenceTransformer
+
+ # --- Setup Logging ---
+ logging.basicConfig(level=logging.INFO)
+ log = logging.getLogger(__name__)
+
+ # --- Global SentenceTransformer Model ---
+ # This pattern ensures the model is loaded only once on the first request.
+ _embedder_instance = None
+ _embedder_lock = threading.Lock()
+ MODEL_NAME = 'all-MiniLM-L6-v2'
+
+ def _get_embedder():
+     """
+     Lazily and thread-safely initializes and returns the SentenceTransformer embedder.
+     """
+     global _embedder_instance
+     # Use a double-checked lock to avoid acquiring the lock for every request
+     if _embedder_instance is None:
+         with _embedder_lock:
+             # Check again inside the lock to ensure it wasn't initialized by another thread
+             # while the current thread was waiting for the lock.
+             if _embedder_instance is None:
+                 try:
+                     log.info(f"Loading SentenceTransformer model: {MODEL_NAME} (lazy init)...")
+                     _embedder_instance = SentenceTransformer(MODEL_NAME)
+                     log.info("SentenceTransformer model loaded successfully.")
+                 except Exception as e:
+                     log.critical(f"Failed to load SentenceTransformer model: {e}", exc_info=True)
+                     # The instance remains None, so subsequent calls will retry.
+                     _embedder_instance = None
+     return _embedder_instance
+
+ def generate_embeddings(texts: list[str]) -> dict:
+     """
+     Generates embeddings for a list of input texts.
+
+     Args:
+         texts: A list of strings to be embedded.
+
+     Returns:
+         A dictionary containing the list of embedding vectors or an error message.
+     """
+     if not isinstance(texts, list) or not all(isinstance(t, str) for t in texts):
+         # Gradio's JSON component will likely parse it correctly, but this is a good safeguard.
+         log.error("Invalid input: 'texts' must be a list of strings.")
+         return {"error": "Invalid input format. Expected a list of strings."}
+
+     embedder = _get_embedder()
+     if embedder is None:
+         log.error("Embedder not available. Cannot generate embeddings.")
+         # We don't return a 500 error here so the client can see the message.
+         # In a real production system, you might raise an exception to trigger a 500.
+         return {"error": "Embedding model is not available. Please check the server logs."}
+
+     try:
+         log.info(f"Generating embeddings for {len(texts)} text(s).")
+         # The encode function is thread-safe.
+         embeddings = embedder.encode(texts, convert_to_numpy=True).tolist()
+         log.info("Embeddings generated successfully.")
+         return {"embeddings": embeddings}
+     except Exception as e:
+         log.error(f"An error occurred during embedding generation: {e}", exc_info=True)
+         return {"error": f"An unexpected error occurred: {e}"}
+
+ # --- Create the Gradio Interface ---
+ # We use gr.JSON for both input and output for maximum flexibility and API-friendliness.
+ description = """
+ ### Sentence Embedding API
+ This API provides access to the `all-MiniLM-L6-v2` sentence embedding model.
+
+ **How to use the API:**
+ 1. Send a POST request to the `/api/generate_embeddings/` endpoint.
+ 2. The body of the request should be a JSON object with a "data" key.
+ 3. The value of "data" should be an array containing one element: a list of the texts you want to embed.
+
+ **Example using `curl`:**
+
+     curl -X POST "https://YOUR-SPACE-NAME.hf.space/api/generate_embeddings/" \\
+          -H "Content-Type: application/json" \\
+          -d '{"data": [["Hello, world!", "This is another sentence."]]}'
+
+ **Expected Success Response (JSON):**
+
+     {
+       "data": [
+         {
+           "embeddings": [
+             [-0.0139..., 0.0523..., ..., -0.0111...],
+             [0.0229..., -0.0149..., ..., 0.0515...]
+           ]
+         }
+       ],
+       "is_generating": false,
+       "duration": 0.5,
+       "average_duration": 0.5
+     }
+ """
+
+ demo = gr.Interface(
+     fn=generate_embeddings,
+     inputs=gr.JSON(
+         label="Input Texts",
+         info='Provide a list of strings, e.g., ["text 1", "text 2"]'
+     ),
+     outputs=gr.JSON(label="Output Embeddings"),
+     title="Sentence Embedding API Service",
+     description=description,
+     examples=[
+ [[["Hello world", "Gradio is a great tool for building ML apps."]]],
115
+ [[["What is the capital of France?"]]]
116
+ ],
117
+ api_name="generate_embeddings" # This creates the /api/generate_embeddings/ endpoint
118
+ )
119
+
120
+ if __name__ == "__main__":
121
+ demo.launch()
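
For reference, a client-side call to the endpoint this commit exposes could look like the sketch below. It is a minimal illustration using the `gradio_client` package; the Space URL is a placeholder, and the exact request path and return shape depend on the Gradio version running in the Space, so treat it as a sketch rather than a verified client.

```python
# Minimal client sketch (assumptions: the Space is reachable at the placeholder URL below
# and exposes the api_name defined in app.py; behavior may vary across Gradio versions).
from gradio_client import Client

client = Client("https://YOUR-SPACE-NAME.hf.space/")  # placeholder Space URL

# One positional argument per input component: the gr.JSON input receives the list of texts.
result = client.predict(
    ["Hello, world!", "This is another sentence."],
    api_name="/generate_embeddings",
)

# On success, `result` mirrors the dict returned by generate_embeddings().
print(len(result["embeddings"]), "embedding vectors returned")
```

The raw `curl` example embedded in the app's description targets the same endpoint; note that the exact HTTP path depends on the Gradio major version installed in the Space, so the `/api/generate_embeddings/` path shown there may need adjusting.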