astro21 commited on
Commit
64c1f09
·
verified ·
1 Parent(s): 4e4d7b0

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +17 -9
main.py CHANGED
@@ -1,16 +1,16 @@
1
- from fastapi import FastAPI, Request
2
  from pydantic import BaseModel
3
  import transformers
4
- import torch
5
  from fastapi.middleware.cors import CORSMiddleware
6
-
7
-
8
  import os
 
 
 
9
  access_token_read = os.getenv('DS4')
10
  print(access_token_read)
11
 
12
- from huggingface_hub import login
13
- login(token = access_token_read)
14
 
15
  # Define the FastAPI app
16
  app = FastAPI()
@@ -22,19 +22,22 @@ app.add_middleware(
22
  allow_headers=["*"],
23
  )
24
 
25
- # Load the model and tokenizer from Hugging Face
26
  model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct" # Replace with an appropriate model
27
  tokenizer = transformers.AutoTokenizer.from_pretrained(model_id)
28
  model = transformers.AutoModelForCausalLM.from_pretrained(
29
- model_id, device_map="auto", torch_dtype=torch.bfloat16
 
30
  )
 
 
31
  pipeline = transformers.pipeline(
32
  "text-generation",
33
  model=model,
34
  tokenizer=tokenizer,
35
  max_new_tokens=150,
36
  temperature=0.7,
37
- device_map="auto",
38
  )
39
 
40
  # Define the request model for email input
@@ -44,6 +47,11 @@ class EmailRequest(BaseModel):
44
  recipients: str
45
  body: str
46
 
 
 
 
 
 
47
  # Define the FastAPI endpoint for email summarization
48
  @app.post("/summarize-email/")
49
  async def summarize_email(email: EmailRequest):
 
1
+ from fastapi import FastAPI
2
  from pydantic import BaseModel
3
  import transformers
 
4
  from fastapi.middleware.cors import CORSMiddleware
 
 
5
  import os
6
+ from huggingface_hub import login
7
+
8
+ # Get access token from environment variable
9
  access_token_read = os.getenv('DS4')
10
  print(access_token_read)
11
 
12
+ # Login to Hugging Face Hub
13
+ login(token=access_token_read)
14
 
15
  # Define the FastAPI app
16
  app = FastAPI()
 
22
  allow_headers=["*"],
23
  )
24
 
25
+ # Load the model and tokenizer from Hugging Face, set device to CPU
26
  model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct" # Replace with an appropriate model
27
  tokenizer = transformers.AutoTokenizer.from_pretrained(model_id)
28
  model = transformers.AutoModelForCausalLM.from_pretrained(
29
+ model_id,
30
+ # Removed device_map and low_cpu_mem_usage to avoid the need for 'accelerate'
31
  )
32
+
33
+ # Set up the text generation pipeline for CPU
34
  pipeline = transformers.pipeline(
35
  "text-generation",
36
  model=model,
37
  tokenizer=tokenizer,
38
  max_new_tokens=150,
39
  temperature=0.7,
40
+ device=-1 # Force CPU usage
41
  )
42
 
43
  # Define the request model for email input
 
47
  recipients: str
48
  body: str
49
 
50
+ # Helper function to create the email prompt
51
+ def create_email_prompt(subject, sender, recipients, body):
52
+ prompt = f"Subject: {subject}\nFrom: {sender}\nTo: {recipients}\n\n{body}\n\nSummarize this email."
53
+ return prompt
54
+
55
  # Define the FastAPI endpoint for email summarization
56
  @app.post("/summarize-email/")
57
  async def summarize_email(email: EmailRequest):