# -*- coding: utf-8 -*-

__author__ = "Roland Trouville"
__copyright__ = "Copyright 2021+, LPP"
__license__ = "Creative Commons 4.0 By-Nc-Sa"
__maintainer__ = "Roland Trouville"
__email__ = "roland.trouville@sorbonne-nouvelle.fr"
__status__ = "Production"

import os
from bisect import bisect_right

import cv2
from PyQt5.QtGui import QImage, QPixmap
from PyQt5.QtWidgets import QProgressBar, QApplication

from tools.data_manager import DataManager


class VideoManager(object):
	"""Frame-level access to a video file through images stored on disk.

	On construction the video is opened with OpenCV; if the file
	``<basepath>frames/0<ext>`` does not exist yet, every frame is written to
	``<basepath>frames/<frame_id><ext>``. Frames are later served as QPixmap
	objects loaded from those image files and kept in a small in-memory cache.

	``basepath`` is concatenated as-is with ``"frames/"``, so it is expected
	to end with a path separator (or be empty).
	"""

	# The frame cache is emptied entirely once it holds more than this many entries.
	MAX_FRAMES_IN_MEMORY = 500
	# File extension (and therefore image format) of the frames written to disk.
	FRAME_IMG_EXTENSION = ".jpg"

	fname = None               # path of the video file handed to OpenCV
	framerate = None           # frames per second, as reported by OpenCV
	last_frame_id = None       # id of the last usable frame (frame ids start at 0)
	cv2h = None                # cv2.VideoCapture handle
	frame_timecodes = None     # dict frame_id -> start time in seconds (built lazily)
	frames = None              # dict frame_id -> QPixmap (in-memory cache)
	basepath = None            # prefix under which the "frames/" directory lives
	one_frame_duration = None  # duration of one frame in seconds (1 / framerate)
	duration = None            # duration computed from the announced frame count
	cutoff = None              # duration computed from the frames actually available

	pbar = None                # optional QProgressBar updated during extraction

	# Start times in ascending order, same content as frame_timecodes.values();
	# kept separately so lookups can use a binary search.
	_tc_list = None

	def __init__(self, fname: str, basepath: str, pbar: "QProgressBar" = None):
		"""Open the video and make sure its frames are available on disk.

		:param fname: path of the video file.
		:param basepath: prefix of the ``frames/`` directory (see class doc).
		:param pbar: optional progress bar updated while frames are extracted.
		:raises ValueError: if the video cannot be opened or reports a
			non-positive frame rate.
		"""
		self.fname = fname
		self.pbar = pbar
		self.basepath = basepath
		self.cv2h = cv2.VideoCapture(self.fname)
		self.framerate = float(self.cv2h.get(cv2.CAP_PROP_FPS))
		# Fail with an explicit message instead of a ZeroDivisionError below.
		if not self.cv2h.isOpened() or self.framerate <= 0:
			raise ValueError("Cannot read video file '%s' (reported framerate: %f)" % (self.fname, self.framerate))
		self.last_frame_id = int(self.cv2h.get(cv2.CAP_PROP_FRAME_COUNT)) - 1
		self.one_frame_duration = 1 / self.framerate
		self.duration = (self.last_frame_id + 1) * self.one_frame_duration
		DataManager.log("Video File: %f FPS, %d frames, %0.4f seconds per frame, %0.4f seconds total" % (self.framerate, self.last_frame_id + 1, self.one_frame_duration, self.duration))
		# The presence of frame 0 on disk is used as the "already extracted" marker.
		if not os.path.isfile(self._frame_path(0)):
			self.extract_frames_to_disk()
		else:
			self.check_last_frame_id()
		# last_frame_id may have been lowered by the two calls above.
		self.cutoff = (self.last_frame_id + 1) * self.one_frame_duration
		DataManager.log("Video File: %d frames found. Cutoff time set to %0.4f seconds" % (self.last_frame_id + 1, self.cutoff))

	def _frame_path(self, frame_id: int) -> str:
		"""Return the path of the image file holding frame ``frame_id``."""
		return self.basepath + "frames/" + str(frame_id) + VideoManager.FRAME_IMG_EXTENSION

	def get_frame_id_at_tc(self, timecode: float) -> int:
		"""Return the id of the frame displayed at ``timecode`` (seconds).

		Timecodes at or past the start of the last frame map to the last
		frame; timecodes before the start of the video map to frame 0.
		"""
		if self.frame_timecodes is None:
			self.prepare_tcs()
		if self._tc_list is None:
			# frame_timecodes was filled by someone else than prepare_tcs().
			self._tc_list = sorted(self.frame_timecodes.values())
		# Index of the last frame whose start time is <= timecode (-1 if none).
		idx = bisect_right(self._tc_list, timecode) - 1
		# Clamp into [0, last_frame_id]; min() is applied last so that an empty
		# video (last_frame_id == -1) still yields last_frame_id.
		return min(max(idx, 0), self.last_frame_id)

	def get_frame_at_tc(self, timecode: float) -> "QPixmap":
		"""Return the frame displayed at ``timecode`` (seconds)."""
		frame_id = self.get_frame_id_at_tc(timecode)
		return self.get_frame(frame_id)

	def get_frame(self, frame_id: int) -> "QPixmap":
		"""Return frame ``frame_id`` as a QPixmap, loading it from disk if needed.

		:raises Exception: if no image file exists for this frame id.
		"""
		# Crude memory cap: drop the whole cache once it has grown too large.
		if self.frames is None or len(self.frames) > VideoManager.MAX_FRAMES_IN_MEMORY:
			self.frames = {}
		if frame_id not in self.frames:
			path = self._frame_path(frame_id)
			if not os.path.isfile(path):
				raise Exception("No frame found for frame id " + str(frame_id))
			self.frames[frame_id] = QPixmap(path)
		return self.frames[frame_id]

	def get_tc_for_frame(self, frame_id: int):
		"""Return the start time (seconds) of frame ``frame_id``, or 0 if unknown."""
		if self.frame_timecodes is None:
			self.prepare_tcs()
		try:
			return self.frame_timecodes[frame_id]
		except KeyError:
			return 0

	def check_last_frame_id(self):
		"""Lower ``last_frame_id`` to the highest frame id actually present on disk.

		Scans downwards from the announced last frame, frame 0 included, and
		stops at the first image file found. Nothing changes if no file is found.
		"""
		for i in range(self.last_frame_id, -1, -1):
			if os.path.isfile(self._frame_path(i)):
				if self.last_frame_id > i:
					DataManager.log("Video File: Changed last frame id from %d (announced by file) to %d (found in frames/ dir)" % (self.last_frame_id, i))
					self.last_frame_id = i
				return

	def prepare_tcs(self):
		"""Build the frame_id -> start time table for frames 0..last_frame_id."""
		timecodes = {}
		tc = 0
		for i in range(self.last_frame_id + 1):
			timecodes[i] = tc
			tc += self.one_frame_duration
		self.frame_timecodes = timecodes
		# Values were inserted in ascending order, so this list is already sorted.
		self._tc_list = list(timecodes.values())

	def extract_frames_to_disk(self):
		"""Write every frame of the video to ``<basepath>frames/<frame_id><ext>``.

		Extraction stops at the first frame that cannot be read or written;
		``last_frame_id`` is then lowered to the last frame actually saved.
		The progress bar, if any, is updated after each frame.
		"""
		# cv2.imwrite does not create missing directories.
		os.makedirs(self.basepath + "frames/", exist_ok=True)
		frame_id = 0
		if self.pbar is not None:
			# frame_id counts saved frames, so it ends at last_frame_id + 1.
			self.pbar.setMaximum(self.last_frame_id + 1)
			self.pbar.setValue(frame_id)
		while frame_id <= self.last_frame_id:
			ret, frame = self.cv2h.read()
			if not ret:
				DataManager.log("Video File: frame %d could not be read, stopping extraction" % frame_id)
				break
			if not cv2.imwrite(self._frame_path(frame_id), frame):
				DataManager.log("Video File: frame %d could not be written to frames/ dir, stopping extraction" % frame_id)
				break
			frame_id += 1
			if self.pbar is not None:
				self.pbar.setValue(frame_id)
				self.pbar.update()
				# Keep the GUI responsive during this long blocking loop.
				QApplication.processEvents()

		# Here frame_id is the number of frames saved; the last saved id is frame_id - 1.
		if self.last_frame_id > frame_id - 1:
			DataManager.log("Video File: File announced %d frames, only %d found and saved in frames/ dir" % (self.last_frame_id + 1, frame_id))
			self.last_frame_id = frame_id - 1




