telegram-file-to-link-bot/bot/handlers/upload.py

179 lines
4.7 KiB
Python

# Copyright 2025 Aman
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
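# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.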
import asyncio
import uuid
import os
import re
from datetime import datetime, timedelta, timezone
import boto3
from pyrogram import filters
from bot.bot import tg_client
from cache.redis import redis_client
from bot.utils.access import is_allowed
from bot.utils.mode import get_mode, format_ttl
from config import (
BASE_URL,
MAX_FILE_MB,
MAX_CONCURRENT_TRANSFERS,
STORAGE_BACKEND,
AWS_ENDPOINT_URL,
AWS_S3_BUCKET_NAME,
AWS_DEFAULT_REGION,
)
from db.database import Database
# Absolute path of the local storage directory; created at import time if it
# does not exist yet.
UPLOAD_DIR = os.path.abspath("uploads")
os.makedirs(UPLOAD_DIR, exist_ok=True)

# Caps how many uploads are processed at the same time.
upload_semaphore = asyncio.Semaphore(MAX_CONCURRENT_TRANSFERS)

# The S3 client is only built when the S3 backend is selected; otherwise the
# name stays bound to None.
s3 = (
    boto3.client(
        "s3",
        endpoint_url=AWS_ENDPOINT_URL,
        region_name=AWS_DEFAULT_REGION,
    )
    if STORAGE_BACKEND == "s3"
    else None
)
# Characters replaced in file names: < > : " / \ | ? * and the ASCII control
# codes 0x00-0x1F.
_UNSAFE_FILENAME_CHARS = re.compile(r'[<>:"/\\|?*\x00-\x1F]')


def safe_filename(name: str) -> str:
    """Return *name* with unsafe characters replaced by underscores.

    Replacement happens first, then leading and trailing whitespace is
    stripped, so a leading newline or tab becomes ``_`` rather than being
    removed.
    """
    sanitized = _UNSAFE_FILENAME_CHARS.sub("_", name)
    return sanitized.strip()
# Every media kind the bot accepts for upload.
_UPLOADABLE_MEDIA = (
    filters.document
    | filters.video
    | filters.audio
    | filters.photo
    | filters.animation
    | filters.voice
    | filters.video_note
)


@tg_client.on_message(filters.private & _UPLOADABLE_MEDIA)
async def upload_handler(_, message):
    """Handle a media message sent to the bot in a private chat.

    Messages without a sender are ignored. Senders rejected by ``is_allowed``
    receive an "Unauthorized" reply. For everyone else a status message is
    posted and the upload is processed once ``upload_semaphore`` grants a
    slot.
    """
    sender = message.from_user
    if not sender:
        return

    if is_allowed(sender.id):
        queued_notice = await message.reply("📥 Queued for processing…")
        async with upload_semaphore:
            await process_upload(message, queued_notice)
        return

    await message.reply("🚫 Unauthorized")
def _over_size_limit(size):
    """Return True when *size* (bytes) is known and exceeds ``MAX_FILE_MB``."""
    if MAX_FILE_MB is None or not size:
        return False
    return size > MAX_FILE_MB * 1024 * 1024


def _too_large_text(size):
    """Build the user-facing rejection text for a file of *size* bytes."""
    size_mb = size / (1024 * 1024)
    return (
        "❌ **File too large**\n\n"
        f"Your file: **{size_mb:.2f} MB**\n"
        f"Max allowed: **{MAX_FILE_MB} MB**"
    )


async def _store_file(temp_path, stored_name):
    """Put *temp_path* into the configured backend and return its stored path.

    For the local backend the file is moved into ``UPLOAD_DIR``. For any other
    backend it is uploaded to the S3 bucket under the key *stored_name*; the
    caller is responsible for deleting *temp_path* afterwards.
    """
    if STORAGE_BACKEND == "local":
        internal_path = os.path.join(UPLOAD_DIR, stored_name)
        os.replace(temp_path, internal_path)
        return internal_path

    # boto3's upload_file is synchronous; run it in the default executor so
    # the event loop keeps serving other updates during the transfer.
    loop = asyncio.get_running_loop()
    await loop.run_in_executor(
        None, s3.upload_file, temp_path, AWS_S3_BUCKET_NAME, stored_name
    )
    return stored_name


async def process_upload(message, status):
    """Download the media in *message*, store it, and report the link.

    Args:
        message: The incoming message; one of its media attributes
            (document, video, audio, photo, animation, voice, video_note)
            is used, and it is downloaded via ``message.download()``.
        status: A previously sent message whose text is edited to show
            progress, errors, and the final link.

    The size limit (``MAX_FILE_MB``) is checked before downloading when the
    media reports a size, and again after downloading when it did not. The
    downloaded temp file is always removed, including when storing fails.
    On success a row is inserted into ``files`` and the same metadata is
    written to the ``file:<file_id>`` Redis hash.
    """
    media = (
        message.document
        or message.video
        or message.audio
        or message.photo
        or message.animation
        or message.voice
        or message.video_note
    )

    file_size = getattr(media, "file_size", None)
    if _over_size_limit(file_size):
        await status.edit(_too_large_text(file_size))
        return

    await status.edit("⬇️ Downloading…")
    temp_path = await message.download()
    if not temp_path:
        await status.edit("❌ Download failed")
        return

    try:
        if not file_size:
            # The media did not report a size, so the limit could not be
            # enforced up front; enforce it against the downloaded file.
            file_size = os.path.getsize(temp_path)
            if _over_size_limit(file_size):
                await status.edit(_too_large_text(file_size))
                return

        if message.photo:
            original_name = f"{uuid.uuid4().hex}.jpg"
        elif hasattr(media, "file_name") and media.file_name:
            original_name = safe_filename(media.file_name)
        else:
            original_name = f"{uuid.uuid4().hex}.bin"

        file_id = uuid.uuid4().hex[:12]
        ext = os.path.splitext(original_name)[1]
        stored_path = await _store_file(temp_path, f"{file_id}{ext}")
    finally:
        # Best-effort cleanup: the file is already gone after a successful
        # local move, and must not be left behind on any failure path.
        try:
            os.remove(temp_path)
        except FileNotFoundError:
            pass

    user_mode = get_mode(message.from_user.id)
    if user_mode["ttl"] > 0:
        ttl = user_mode["ttl"]
        ttl_source = "👤 Using your TTL"
    else:
        ttl = 0
        ttl_source = "♾ No expiration"

    expires_at = (
        datetime.now(timezone.utc) + timedelta(seconds=ttl)
        if ttl > 0 else None
    )

    await Database.pool.execute(
        """
INSERT INTO files (
file_id, path, name, downloads, file_size, expires_at
)
VALUES ($1, $2, $3, 0, $4, $5)
""",
        file_id,
        stored_path,
        original_name,
        file_size,
        expires_at,
    )

    # NOTE(review): these calls are not awaited, which assumes redis_client
    # is a synchronous client — confirm against cache.redis.
    redis_client.delete(f"file:{file_id}")
    redis_client.hset(
        f"file:{file_id}",
        mapping={
            "path": stored_path,
            "name": original_name,
            "downloads": 0,
            "file_size": file_size,
            # 0 stands for "no expiry" in the cached record.
            "expires_at": int(expires_at.timestamp()) if expires_at else 0,
        },
    )

    size_mb = file_size / (1024 * 1024)
    await status.edit(
        "✅ **File uploaded**\n\n"
        f"{ttl_source}\n"
        f"📄 **Name:** `{original_name}`\n"
        f"📦 **Size:** `{size_mb:.2f} MB`\n"
        f"⏳ **Expires:** {format_ttl(ttl)}\n\n"
        f"🔗 `{BASE_URL}/file/{file_id}`"
    )