Neepurna committed
Commit 388749d · Parent: 22dea70
Files changed (4)
  1. Dockerfile +22 -0
  2. app/__init__.py +0 -0
  3. app/main.py +42 -0
  4. requirements.txt +6 -0
Dockerfile ADDED
@@ -0,0 +1,22 @@
+ # Use the official Python image as the base
+ FROM python:3.9-slim
+
+ # Set environment variables
+ ENV PYTHONUNBUFFERED=1 \
+     PYTHONDONTWRITEBYTECODE=1
+
+ # Create a working directory
+ WORKDIR /app
+
+ # Copy requirements and install dependencies
+ COPY requirements.txt .
+ RUN pip install --no-cache-dir -r requirements.txt
+
+ # Copy the application code
+ COPY . .
+
+ # Expose the port FastAPI will run on
+ EXPOSE 8000
+
+ # Command to run the application (the FastAPI instance lives in app/main.py)
+ CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]
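Not part of the commit, but a minimal smoke test for the built image, assuming it was built and started with commands along the lines of docker build -t rag-service . and docker run -p 8000:8000 rag-service (the rag-service tag is illustrative). It uses the requests package that requirements.txt already pins; note that on first start the container downloads both models before it begins serving.

import requests

# Ask the containerized service to answer a question over two candidate documents.
payload = {
    "question": "What port does the service listen on?",
    "documents": [
        "Uvicorn serves the FastAPI app on port 8000.",
        "The retriever model is all-MiniLM-L6-v2.",
    ],
}

resp = requests.post("http://localhost:8000/generate", json=payload, timeout=300)
resp.raise_for_status()
print(resp.json())  # e.g. {"answer": "..."}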
app/__init__.py ADDED
Empty file (package marker)
app/main.py ADDED
@@ -0,0 +1,42 @@
+ from fastapi import FastAPI, HTTPException
+ from pydantic import BaseModel
+ from sentence_transformers import SentenceTransformer, util
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
+ import torch
+
+ app = FastAPI()
+
+ # Load the retriever model
+ retriever = SentenceTransformer('all-MiniLM-L6-v2')
+
+ # Load the generator model
+ tokenizer = AutoTokenizer.from_pretrained('facebook/bart-large-cnn')
+ generator = AutoModelForSeq2SeqLM.from_pretrained('facebook/bart-large-cnn')
+
+ class Query(BaseModel):
+     question: str
+     documents: list[str]
+
+ @app.post("/generate")
+ async def generate_answer(query: Query):
+     if not query.documents:
+         raise HTTPException(status_code=400, detail="No documents provided.")
+
+     # Encode the documents and the query
+     doc_embeddings = retriever.encode(query.documents, convert_to_tensor=True)
+     query_embedding = retriever.encode(query.question, convert_to_tensor=True)
+
+     # Compute cosine similarities and pick the most relevant document
+     similarities = util.pytorch_cos_sim(query_embedding, doc_embeddings)[0]
+     top_doc_index = torch.argmax(similarities).item()
+     top_doc = query.documents[top_doc_index]
+
+     # Prepare input for the generator
+     input_text = f"question: {query.question} context: {top_doc}"
+     inputs = tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True)
+
+     # Generate the answer
+     output_ids = generator.generate(inputs.input_ids, max_length=150, num_beams=5, early_stopping=True)
+     answer = tokenizer.decode(output_ids[0], skip_special_tokens=True)
+
+     return {"answer": answer}
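As a sketch, the /generate endpoint can also be exercised in-process with FastAPI's TestClient. This assumes httpx is available at test time (Starlette's TestClient depends on it, and it is not pinned in the requirements.txt below):

from fastapi.testclient import TestClient

from app.main import app

client = TestClient(app)

def test_generate_returns_answer():
    # The retriever should pick the first document and BART should answer from it.
    payload = {
        "question": "Where does Uvicorn listen?",
        "documents": [
            "Uvicorn serves the app on port 8000.",
            "The retriever model is all-MiniLM-L6-v2.",
        ],
    }
    resp = client.post("/generate", json=payload)
    assert resp.status_code == 200
    assert "answer" in resp.json()

def test_generate_rejects_empty_documents():
    # An empty document list should trip the explicit 400 guard.
    resp = client.post("/generate", json={"question": "anything", "documents": []})
    assert resp.status_code == 400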
requirements.txt ADDED
@@ -0,0 +1,6 @@
+ fastapi==0.99.1
+ uvicorn[standard]==0.22.0
+ transformers==4.33.3
+ sentence-transformers==2.2.2
+ torch==2.0.1
+ requests==2.31.0