Darshan committed
Commit ad152ab · 1 Parent(s): bae6852

use different app for testing

Files changed (2):
  1. Dockerfile +3 -3
  2. app.py +16 -23
Dockerfile CHANGED

@@ -2,10 +2,10 @@
 FROM python:3.10.9
 
 # Copy the current directory contents into the container at .
-COPY ./app ./app
+COPY . .
 
 # Set the working directory to /
-WORKDIR /trans
+WORKDIR /
 
 EXPOSE 7860
 
@@ -13,4 +13,4 @@ EXPOSE 7860
 RUN pip install --no-cache-dir --upgrade -r /requirements.txt
 
 # Start the FastAPI app on port 7860, the default port expected by Spaces
-CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "7860"]
+CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
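With the new CMD, uvicorn serves the module-level app object from app.py on port 7860 instead of app/main.py. A minimal smoke test, assuming the image has been built and started locally with port 7860 published (the local URL below is an assumption, not part of the commit):

# smoke_test.py -- check that the container's uvicorn server is reachable.
# Assumes the container maps host port 7860 to container port 7860.
import json
import urllib.request

BASE_URL = "http://localhost:7860"  # hypothetical local mapping

with urllib.request.urlopen(f"{BASE_URL}/") as resp:
    body = json.loads(resp.read().decode("utf-8"))
    # the new app.py defines a "/" route that returns {"Hello": "World"}
    print(resp.status, body)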
app.py CHANGED

@@ -1,17 +1,12 @@
-from fastapi import FastAPI
+from fastapi import FastAPI, HTTPException
 from typing import List
-import torch
+from pydantic import BaseModel
 from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
 from IndicTransToolkit import IndicProcessor
 from fastapi.middleware.cors import CORSMiddleware
 
-import os
-
-os.environ["HF_HOME"] = "/.cache"
-# Initialize FastAPI
 app = FastAPI()
 
-# Add CORS middleware
 app.add_middleware(
     CORSMiddleware,
     allow_origins=["*"],
@@ -20,13 +15,13 @@ app.add_middleware(
     allow_headers=["*"],
 )
 
-# Initialize models and processors
 model = AutoModelForSeq2SeqLM.from_pretrained(
     "ai4bharat/indictrans2-en-indic-1B", trust_remote_code=True
 )
 tokenizer = AutoTokenizer.from_pretrained(
     "ai4bharat/indictrans2-en-indic-1B", trust_remote_code=True
 )
+
 ip = IndicProcessor(inference=True)
 DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 model = model.to(DEVICE)
@@ -58,29 +53,27 @@ def translate_text(sentences: List[str], target_lang: str):
         generated_tokens = tokenizer.batch_decode(
             generated_tokens.detach().cpu().tolist(),
             skip_special_tokens=True,
-            clean_up_tokenization_spaces=True,
         )
 
-        translations = ip.postprocess_batch(generated_tokens, lang=target_lang)
-        return {
-            "translations": translations,
-            "source_language": src_lang,
-            "target_language": target_lang,
-        }
+        return generated_tokens
     except Exception as e:
-        raise Exception(f"Translation failed: {str(e)}")
+        return str(e)
+
+
+@app.get("/")
+def read_root():
+    return {"Hello": "World"}
 
 
-# FastAPI routes
-@app.get("/health")
-async def health_check():
-    return {"status": "healthy"}
+class TranslateRequest(BaseModel):
+    sentences: List[str]
+    target_lang: str
 
 
-@app.post("/translate")
-async def translate_endpoint(sentences: List[str], target_lang: str):
+@app.post("/translate/")
+def translate(request: TranslateRequest):
     try:
-        result = translate_text(sentences=sentences, target_lang=target_lang)
+        result = translate_text(request.sentences, request.target_lang)
         return result
     except Exception as e:
         raise HTTPException(status_code=500, detail=str(e))
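After this change the POST endpoint reads its arguments from a JSON body validated by the TranslateRequest model rather than from query parameters. A minimal client sketch, assuming the service is reachable at a local URL and that target_lang takes an IndicTrans2-style language code (both assumptions, not shown in the commit):

# client.py -- hypothetical client for the new /translate/ endpoint.
# BASE_URL and the language code "hin_Deva" are assumptions for illustration.
import json
import urllib.request

BASE_URL = "http://localhost:7860"

payload = {
    "sentences": ["Hello, how are you?"],
    "target_lang": "hin_Deva",  # assumed target-language code
}

req = urllib.request.Request(
    f"{BASE_URL}/translate/",
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
    method="POST",
)

with urllib.request.urlopen(req) as resp:
    # With the committed code, the endpoint returns the decoded token strings
    # (a list), or an error string if translate_text caught an exception.
    print(json.loads(resp.read().decode("utf-8")))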